Merge branch 'arc53:main' into basic-ui

Manish Madan
2025-01-03 18:13:01 +05:30
committed by GitHub
66 changed files with 3033 additions and 2869 deletions

View File

@@ -8,14 +8,14 @@ RUN apt-get update && \
add-apt-repository ppa:deadsnakes/ppa && \
# Install necessary packages and Python
apt-get update && \
-apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.11 python3.11-distutils python3.11-venv && \
+apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
-RUN if [ -f /usr/bin/python3.11 ]; then \
-ln -s /usr/bin/python3.11 /usr/bin/python; \
+RUN if [ -f /usr/bin/python3.12 ]; then \
+ln -s /usr/bin/python3.12 /usr/bin/python; \
else \
-echo "Python 3.11 not found"; exit 1; \
+echo "Python 3.12 not found"; exit 1; \
fi
# Download and unzip the model
@@ -33,7 +33,7 @@ RUN apt-get remove --purge -y wget unzip && apt-get autoremove -y && rm -rf /var
COPY requirements.txt .
# Setup Python virtual environment
-RUN python3.11 -m venv /venv
+RUN python3.12 -m venv /venv
# Activate virtual environment and install Python packages
ENV PATH="/venv/bin:$PATH"
@@ -50,8 +50,8 @@ RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
# Install Python
-apt-get update && apt-get install -y --no-install-recommends python3.11 && \
-ln -s /usr/bin/python3.11 /usr/bin/python && \
+apt-get update && apt-get install -y --no-install-recommends python3.12 && \
+ln -s /usr/bin/python3.12 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*
# Set working directory

View File

@@ -18,7 +18,7 @@ from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
-from application.utils import check_required_fields
+from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
@@ -37,7 +37,7 @@ api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_NAME == "openai":
gpt_model = "gpt-3.5-turbo"
gpt_model = "gpt-4o-mini"
elif settings.LLM_NAME == "anthropic":
gpt_model = "claude-2"
elif settings.LLM_NAME == "groq":
@@ -324,8 +324,7 @@ class Stream(Resource):
try:
question = data["question"]
history = str(data.get("history", []))
history = str(json.loads(history))
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
@@ -456,7 +455,7 @@ class Answer(Resource):
try:
question = data["question"]
history = data.get("history", [])
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))

View File

@@ -1,14 +1,14 @@
import datetime
+import math
import os
import shutil
import uuid
-import math
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
-from flask import Blueprint, jsonify, make_response, request, redirect
-from flask_restx import inputs, fields, Namespace, Resource
+from flask import Blueprint, jsonify, make_response, redirect, request
+from flask_restx import fields, inputs, Namespace, Resource
from werkzeug.utils import secure_filename
from application.api.user.tasks import ingest, ingest_remote
@@ -16,9 +16,10 @@ from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
+from application.tools.tool_manager import ToolManager
+from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
-from application.tts.google_tts import GoogleTTS
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -30,6 +31,7 @@ api_key_collection = db["api_keys"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
user = Blueprint("user", __name__)
user_ns = Namespace("user", description="User related operations", path="/")
@@ -39,6 +41,9 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
tool_config = {}
tool_manager = ToolManager(config=tool_config)
def generate_minute_range(start_date, end_date):
return {
@@ -176,10 +181,12 @@ class SubmitFeedback(Resource):
"FeedbackModel",
{
"question": fields.String(
required=True, description="The user question"
required=False, description="The user question"
),
"answer": fields.String(required=True, description="The AI answer"),
"answer": fields.String(required=False, description="The AI answer"),
"feedback": fields.String(required=True, description="User feedback"),
"question_index":fields.Integer(required=True, description="The question number in that particular conversation"),
"conversation_id":fields.String(required=True, description="id of the particular conversation"),
"api_key": fields.String(description="Optional API key"),
},
)
@@ -189,23 +196,21 @@ class SubmitFeedback(Resource):
)
def post(self):
data = request.get_json()
required_fields = ["question", "answer", "feedback"]
required_fields = [ "feedback","conversation_id","question_index"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
new_doc = {
"question": data["question"],
"answer": data["answer"],
"feedback": data["feedback"],
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if "api_key" in data:
new_doc["api_key"] = data["api_key"]
try:
feedback_collection.insert_one(new_doc)
conversations_collection.update_one(
{"_id": ObjectId(data["conversation_id"]), f"queries.{data['question_index']}": {"$exists": True}},
{
"$set": {
f"queries.{data['question_index']}.feedback": data["feedback"]
}
}
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@@ -1802,3 +1807,295 @@ class TextToSpeech(Resource):
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@user_ns.route("/api/available_tools")
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
doc = tool_instance.__doc__.strip()
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"actions": tool_instance.get_actions_metadata(),
}
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
@user_ns.route("/api/get_tools")
class GetTools(Resource):
@api.doc(description="Get tools created by a user")
def get(self):
try:
user = "local"
tools = user_tools_collection.find({"user": user})
user_tools = []
for tool in tools:
tool["id"] = str(tool["_id"])
tool.pop("_id")
user_tools.append(tool)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
@user_ns.route("/api/create_tool")
class CreateTool(Resource):
@api.expect(
api.model(
"CreateToolModel",
{
"name": fields.String(required=True, description="Name of the tool"),
"displayName": fields.String(
required=True, description="Display name for the tool"
),
"description": fields.String(
required=True, description="Tool description"
),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Create a new tool")
def post(self):
data = request.get_json()
required_fields = [
"name",
"displayName",
"description",
"actions",
"config",
"status",
]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user = "local"
transformed_actions = []
for action in data["actions"]:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
try:
new_tool = {
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"actions": transformed_actions,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"id": new_id}), 200)
@user_ns.route("/api/update_tool")
class UpdateTool(Resource):
@api.expect(
api.model(
"UpdateToolModel",
{
"id": fields.String(required=True, description="Tool ID"),
"name": fields.String(description="Name of the tool"),
"displayName": fields.String(description="Display name for the tool"),
"description": fields.String(description="Tool description"),
"config": fields.Raw(description="Configuration of the tool"),
"actions": fields.List(
fields.Raw, description="Actions the tool can perform"
),
"status": fields.Boolean(description="Status of the tool"),
},
)
)
@api.doc(description="Update a tool by ID")
def post(self):
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
if "config" in data:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": "local"},
{"$set": update_data},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_config")
class UpdateToolConfig(Resource):
@api.expect(
api.model(
"UpdateToolConfigModel",
{
"id": fields.String(required=True, description="Tool ID"),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
},
)
)
@api.doc(description="Update the configuration of a tool")
def post(self):
data = request.get_json()
required_fields = ["id", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"])},
{"$set": {"config": data["config"]}},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_actions")
class UpdateToolActions(Resource):
@api.expect(
api.model(
"UpdateToolActionsModel",
{
"id": fields.String(required=True, description="Tool ID"),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
},
)
)
@api.doc(description="Update the actions of a tool")
def post(self):
data = request.get_json()
required_fields = ["id", "actions"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"])},
{"$set": {"actions": data["actions"]}},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_status")
class UpdateToolStatus(Resource):
@api.expect(
api.model(
"UpdateToolStatusModel",
{
"id": fields.String(required=True, description="Tool ID"),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Update the status of a tool")
def post(self):
data = request.get_json()
required_fields = ["id", "status"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"])},
{"$set": {"status": data["status"]}},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/delete_tool")
class DeleteTool(Resource):
@api.expect(
api.model(
"DeleteToolModel",
{"id": fields.String(required=True, description="Tool ID")},
)
)
@api.doc(description="Delete a tool by ID")
def post(self):
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])})
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
return {"success": False, "error": str(err)}, 400
return {"success": True}, 200

View File

@@ -1,8 +1,10 @@
-import redis
-import time
import json
import logging
+import time
+from threading import Lock
+
+import redis
from application.core.settings import settings
from application.utils import get_hash
@@ -11,41 +13,47 @@ logger = logging.getLogger(__name__)
_redis_instance = None
_instance_lock = Lock()
def get_redis_instance():
global _redis_instance
if _redis_instance is None:
with _instance_lock:
if _redis_instance is None:
try:
-_redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
+_redis_instance = redis.Redis.from_url(
+settings.CACHE_REDIS_URL, socket_connect_timeout=2
+)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
_redis_instance = None
return _redis_instance
def gen_cache_key(*messages, model="docgpt"):
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
-messages_str = json.dumps(list(messages), sort_keys=True)
-combined = f"{model}_{messages_str}"
+messages_str = json.dumps(messages)
+tools_str = json.dumps(tools) if tools else ""
+combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)
return cache_key
def gen_cache(func):
-def wrapper(self, model, messages, *args, **kwargs):
+def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
try:
-cache_key = gen_cache_key(*messages)
+cache_key = gen_cache_key(messages, model, tools)
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
-return cached_response.decode('utf-8')
+return cached_response.decode("utf-8")
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
-result = func(self, model, messages, *args, **kwargs)
-if redis_client:
+result = func(self, model, messages, stream, tools, *args, **kwargs)
+if redis_client and isinstance(result, str):
try:
redis_client.set(cache_key, result, ex=1800)
except redis.ConnectionError as e:
@@ -55,20 +63,22 @@ def gen_cache(func):
except ValueError as e:
logger.error(e)
return "Error: No user message found in the conversation to generate a cache key."
return wrapper
def stream_cache(func):
def wrapper(self, model, messages, stream, *args, **kwargs):
-cache_key = gen_cache_key(*messages)
+cache_key = gen_cache_key(messages)
logger.info(f"Stream cache key: {cache_key}")
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
logger.info(f"Cache hit for stream key: {cache_key}")
-cached_response = json.loads(cached_response.decode('utf-8'))
+cached_response = json.loads(cached_response.decode("utf-8"))
for chunk in cached_response:
yield chunk
time.sleep(0.03)
@@ -78,16 +88,16 @@ def stream_cache(func):
result = func(self, model, messages, stream, *args, **kwargs)
stream_cache_data = []
for chunk in result:
stream_cache_data.append(chunk)
yield chunk
if redis_client:
try:
redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
logger.info(f"Stream cache saved for key: {cache_key}")
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return wrapper
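
The cache key now folds the model and any tool definitions into the hash, so identical conversations with different tools resolve to different entries. A quick sketch of the intended behaviour:

messages = [{"role": "user", "content": "hi"}]
key_plain = gen_cache_key(messages, model="gpt-4o-mini")
key_tools = gen_cache_key(messages, model="gpt-4o-mini", tools=[{"type": "function"}])
assert key_plain != key_tools  # tool definitions are part of the key

One trade-off worth noting: dropping sort_keys=True means the key depends on message dict ordering, which is stable here only because requests build the messages list the same way every time.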

View File

@@ -16,7 +16,7 @@ class Settings(BaseSettings):
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"

View File

@@ -17,7 +17,7 @@ class AnthropicLLM(BaseLLM):
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
-self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
+self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -34,7 +34,7 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
-self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
+self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]

View File

@@ -13,12 +13,12 @@ class BaseLLM(ABC):
return method(self, *args, **kwargs)
@abstractmethod
-def _raw_gen(self, model, messages, stream, *args, **kwargs):
+def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
pass
-def gen(self, model, messages, stream=False, *args, **kwargs):
+def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
-return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
+return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
@abstractmethod
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
@@ -26,4 +26,10 @@ class BaseLLM(ABC):
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
decorators = [stream_cache, stream_token_usage]
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
def supports_tools(self):
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")

View File

@@ -1,45 +1,32 @@
from application.llm.base import BaseLLM
+from openai import OpenAI
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.user_api_key = user_api_key
-def _raw_gen(
-self,
-baseself,
-model,
-messages,
-stream=False,
-**kwargs
-):
-response = self.client.chat.completions.create(
-model=model, messages=messages, stream=stream, **kwargs
-)
-return response.choices[0].message.content
+def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
+if tools:
+response = self.client.chat.completions.create(
+model=model, messages=messages, stream=stream, tools=tools, **kwargs
+)
+return response.choices[0]
+else:
+response = self.client.chat.completions.create(
+model=model, messages=messages, stream=stream, **kwargs
+)
+return response.choices[0].message.content
def _raw_gen_stream(
-self,
-baseself,
-model,
-messages,
-stream=True,
-**kwargs
-):
+self, baseself, model, messages, stream=True, tools=None, **kwargs
+):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -25,14 +25,20 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
-):
-response = self.client.chat.completions.create(
-model=model, messages=messages, stream=stream, **kwargs
-)
-return response.choices[0].message.content
+):
+if tools:
+response = self.client.chat.completions.create(
+model=model, messages=messages, stream=stream, tools=tools, **kwargs
+)
+return response.choices[0]
+else:
+response = self.client.chat.completions.create(
+model=model, messages=messages, stream=stream, **kwargs
+)
+return response.choices[0].message.content
def _raw_gen_stream(
self,
@@ -40,6 +46,7 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
@@ -52,6 +59,9 @@ class OpenAILLM(BaseLLM):
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content
def _supports_tools(self):
return True
class AzureOpenAILLM(OpenAILLM):

View File

@@ -76,7 +76,7 @@ class SagemakerAPILLM(BaseLLM):
self.endpoint = settings.SAGEMAKER_ENDPOINT
self.runtime = runtime
-def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
+def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -105,7 +105,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]["generated_text"], file=sys.stderr)
return result[0]["generated_text"][len(prompt) :]
-def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
+def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

View File

@@ -0,0 +1,118 @@
import re
from typing import List, Tuple
import logging
from application.parser.schema.base import Document
from application.utils import get_encoding
logger = logging.getLogger(__name__)
class Chunker:
def __init__(
self,
chunking_strategy: str = "classic_chunk",
max_tokens: int = 2000,
min_tokens: int = 150,
duplicate_headers: bool = False,
):
if chunking_strategy not in ["classic_chunk"]:
raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")
self.chunking_strategy = chunking_strategy
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.duplicate_headers = duplicate_headers
self.encoding = get_encoding()
def separate_header_and_body(self, text: str) -> Tuple[str, str]:
header_pattern = r"^(.*?\n){3}"
match = re.match(header_pattern, text)
if match:
header = match.group(0)
body = text[len(header):]
else:
header, body = "", text # No header, treat entire text as body
return header, body
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
combined_text = doc.text + " " + next_doc.text
combined_token_count = len(self.encoding.encode(combined_text))
new_doc = Document(
text=combined_text,
doc_id=doc.doc_id,
embedding=doc.embedding,
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
)
return new_doc
def split_document(self, doc: Document) -> List[Document]:
split_docs = []
header, body = self.separate_header_and_body(doc.text)
header_tokens = self.encoding.encode(header) if header else []
body_tokens = self.encoding.encode(body)
current_position = 0
part_index = 0
while current_position < len(body_tokens):
end_position = current_position + self.max_tokens - len(header_tokens)
chunk_tokens = (header_tokens + body_tokens[current_position:end_position]
if self.duplicate_headers or part_index == 0 else body_tokens[current_position:end_position])
chunk_text = self.encoding.decode(chunk_tokens)
new_doc = Document(
text=chunk_text,
doc_id=f"{doc.doc_id}-{part_index}",
embedding=doc.embedding,
extra_info={**(doc.extra_info or {}), "token_count": len(chunk_tokens)}
)
split_docs.append(new_doc)
current_position = end_position
part_index += 1
header_tokens = []
return split_docs
def classic_chunk(self, documents: List[Document]) -> List[Document]:
processed_docs = []
i = 0
while i < len(documents):
doc = documents[i]
tokens = self.encoding.encode(doc.text)
token_count = len(tokens)
if self.min_tokens <= token_count <= self.max_tokens:
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
elif token_count < self.min_tokens:
if i + 1 < len(documents):
next_doc = documents[i + 1]
next_tokens = self.encoding.encode(next_doc.text)
if token_count + len(next_tokens) <= self.max_tokens:
# Combine small documents
combined_doc = self.combine_documents(doc, next_doc)
processed_docs.append(combined_doc)
i += 2
else:
# Keep the small document as is if adding next_doc would exceed max_tokens
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# No next document to combine with; add the small document as is
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# Split large documents
processed_docs.extend(self.split_document(doc))
i += 1
return processed_docs
def chunk(
self,
documents: List[Document]
) -> List[Document]:
if self.chunking_strategy == "classic_chunk":
return self.classic_chunk(documents)
else:
raise ValueError("Unsupported chunking strategy")

View File

@@ -0,0 +1,86 @@
import os
import logging
from retry import retry
from tqdm import tqdm
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
@retry(tries=10, delay=60)
def add_text_to_store_with_retry(store, doc, source_id):
"""
Add a document's text and metadata to the vector store with retry logic.
Args:
store: The vector store object.
doc: The document to be added.
source_id: Unique identifier for the source.
"""
try:
doc.metadata["source_id"] = str(source_id)
store.add_texts([doc.page_content], metadatas=[doc.metadata])
except Exception as e:
logging.error(f"Failed to add document with retry: {e}")
raise
def embed_and_store_documents(docs, folder_name, source_id, task_status):
"""
Embeds documents and stores them in a vector store.
Args:
docs (list): List of documents to be embedded and stored.
folder_name (str): Directory to save the vector store.
source_id (str): Unique identifier for the source.
task_status: Task state manager for progress updates.
Returns:
None
"""
# Ensure the folder exists
if not os.path.exists(folder_name):
os.makedirs(folder_name)
# Initialize vector store
if settings.VECTOR_STORE == "faiss":
docs_init = [docs.pop(0)]
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=docs_init,
source_id=folder_name,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
else:
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
store.delete_index()
total_docs = len(docs)
# Process and embed documents
for idx, doc in tqdm(
enumerate(docs),
desc="Embedding 🦖",
unit="docs",
total=total_docs,
bar_format="{l_bar}{bar}| Time Left: {remaining}",
):
try:
# Update task status for progress tracking
progress = int(((idx + 1) / total_docs) * 100)
task_status.update_state(state="PROGRESS", meta={"current": progress})
# Add document to vector store
add_text_to_store_with_retry(store, doc, source_id)
except Exception as e:
logging.error(f"Error embedding document {idx}: {e}")
logging.info(f"Saving progress at document {idx} out of {total_docs}")
store.save_local(folder_name)
break
# Save the vector store
if settings.VECTOR_STORE == "faiss":
store.save_local(folder_name)
logging.info("Vector store saved successfully.")

View File

@@ -1,75 +0,0 @@
import os
from retry import retry
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.embeddings import HuggingFaceInstructEmbeddings
# from langchain_community.embeddings import CohereEmbeddings
@retry(tries=10, delay=60)
def store_add_texts_with_retry(store, i, id):
# add source_id to the metadata
i.metadata["source_id"] = str(id)
store.add_texts([i.page_content], metadatas=[i.metadata])
# store_pine.add_texts([i.page_content], metadatas=[i.metadata])
def call_openai_api(docs, folder_name, id, task_status):
# Function to create a vector store from the documents and save it to disk
if not os.path.exists(f"{folder_name}"):
os.makedirs(f"{folder_name}")
from tqdm import tqdm
c1 = 0
if settings.VECTOR_STORE == "faiss":
docs_init = [docs[0]]
docs.pop(0)
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=docs_init,
source_id=f"{folder_name}",
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
else:
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=str(id),
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
store.delete_index()
# Uncomment for MPNet embeddings
# model_name = "sentence-transformers/all-mpnet-base-v2"
# hf = HuggingFaceEmbeddings(model_name=model_name)
# store = FAISS.from_documents(docs_test, hf)
s1 = len(docs)
for i in tqdm(
docs,
desc="Embedding 🦖",
unit="docs",
total=len(docs),
bar_format="{l_bar}{bar}| Time Left: {remaining}",
):
try:
task_status.update_state(
state="PROGRESS", meta={"current": int((c1 / s1) * 100)}
)
store_add_texts_with_retry(store, i, id)
except Exception as e:
print(e)
print("Error on ", i)
print("Saving progress")
print(f"stopped at {c1} out of {len(docs)}")
store.save_local(f"{folder_name}")
break
c1 += 1
if settings.VECTOR_STORE == "faiss":
store.save_local(f"{folder_name}")

View File

@@ -1,79 +0,0 @@
import re
from math import ceil
from typing import List
import tiktoken
from application.parser.schema.base import Document
def separate_header_and_body(text):
header_pattern = r"^(.*?\n){3}"
match = re.match(header_pattern, text)
header = match.group(0)
body = text[len(header):]
return header, body
def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]:
docs = []
current_group = None
for doc in documents:
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
# Check if current group is empty or if the document can be added based on token count and matching metadata
if (current_group is None or
(len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
doc_len < min_tokens and
current_group.extra_info == doc.extra_info)):
if current_group is None:
current_group = doc # Use the document directly to retain its metadata
else:
current_group.text += " " + doc.text # Append text to the current group
else:
docs.append(current_group)
current_group = doc # Start a new group with the current document
if current_group is not None:
docs.append(current_group)
return docs
def split_documents(documents: List[Document], max_tokens: int) -> List[Document]:
docs = []
for doc in documents:
token_length = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
if token_length <= max_tokens:
docs.append(doc)
else:
header, body = separate_header_and_body(doc.text)
if len(tiktoken.get_encoding("cl100k_base").encode(header)) > max_tokens:
body = doc.text
header = ""
num_body_parts = ceil(token_length / max_tokens)
part_length = ceil(len(body) / num_body_parts)
body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)]
for i, body_part in enumerate(body_parts):
new_doc = Document(text=header + body_part.strip(),
doc_id=f"{doc.doc_id}-{i}",
embedding=doc.embedding,
extra_info=doc.extra_info)
docs.append(new_doc)
return docs
def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150, token_check: bool = True):
if not token_check:
return documents
print("Grouping small documents")
try:
documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens)
except Exception:
print("Grouping failed, try running without token_check")
print("Separating large documents")
try:
documents = split_documents(documents=documents, max_tokens=max_tokens)
except Exception:
print("Grouping failed, try running without token_check")
return documents

View File

@@ -1,24 +1,24 @@
anthropic==0.40.0
boto3==1.34.153
beautifulsoup4==4.12.3
-celery==5.3.6
+celery==5.4.0
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==6.3.0
ebooklib==0.18
-elastic-transport==8.15.0
-elasticsearch==8.15.1
+elastic-transport==8.15.1
+elasticsearch==8.17.0
escodegen==1.0.11
esprima==4.0.1
esutils==1.0.1
Flask==3.0.3
-faiss-cpu==1.8.0.post1
+faiss-cpu==1.9.0.post1
flask-restx==1.3.0
gTTS==2.3.2
gunicorn==23.0.0
html2text==2024.2.26
javalang==0.13.0
-jinja2==3.1.4
+jinja2==3.1.5
jiter==0.5.0
jmespath==1.0.1
joblib==1.4.2
@@ -28,22 +28,22 @@ jsonschema==4.23.0
jsonschema-spec==0.2.4
jsonschema-specifications==2023.7.1
kombu==5.4.2
-langchain==0.3.11
-langchain-community==0.3.11
-langchain-core==0.3.25
-langchain-openai==0.2.0
-langchain-text-splitters==0.3.0
-langsmith==0.2.3
+langchain==0.3.13
+langchain-community==0.3.13
+langchain-core==0.3.28
+langchain-openai==0.2.14
+langchain-text-splitters==0.3.4
+langsmith==0.2.6
lazy-object-proxy==1.10.0
lxml==5.3.0
markupsafe==2.1.5
-marshmallow==3.22.0
+marshmallow==3.23.2
mpmath==1.3.0
multidict==6.1.0
mypy-extensions==1.0.0
networkx==3.3
-numpy==1.26.4
-openai==1.55.3
+numpy==2.2.1
+openai==1.58.1
openapi-schema-validator==0.6.2
openapi-spec-validator==0.6.0
openapi3-parser==1.1.18
@@ -68,13 +68,13 @@ python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-pptx==1.0.2
qdrant-client==1.11.0
-redis==5.0.1
+redis==5.2.1
referencing==0.30.2
regex==2024.9.11
requests==2.32.3
retry==0.9.2
-sentence-transformers==3.0.1
-tiktoken==0.7.0
+sentence-transformers==3.3.1
+tiktoken==0.8.0
tokenizers==0.21.0
torch==2.4.1
tqdm==4.66.5
@@ -86,4 +86,4 @@ urllib3==2.2.3
vine==5.1.0
wcwidth==0.2.13
werkzeug==3.1.3
-yarl==1.11.1
+yarl==1.18.3

View File

@@ -2,7 +2,6 @@ import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
-from application.utils import num_tokens_from_string
from langchain_community.tools import BraveSearch
@@ -73,15 +72,8 @@ class BraveRetSearch(BaseRetriever):
yield {"source": doc}
if len(self.chat_history) > 1:
-tokens_current_history = 0
-# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
-tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
-i["response"]
-)
-if tokens_current_history + tokens_batch < self.token_limit:
-tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)

View File

@@ -1,9 +1,9 @@
-from application.retriever.base import BaseRetriever
from application.core.settings import settings
-from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
+from application.retriever.base import BaseRetriever
+from application.tools.agent import Agent
+from application.vectorstore.vector_creator import VectorCreator
from application.utils import num_tokens_from_string
class ClassicRAG(BaseRetriever):
@@ -20,7 +20,7 @@ class ClassicRAG(BaseRetriever):
user_api_key=None,
):
self.question = question
-self.vectorstore = source['active_docs'] if 'active_docs' in source else None
+self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
@@ -73,15 +73,8 @@ class ClassicRAG(BaseRetriever):
yield {"source": doc}
if len(self.chat_history) > 1:
-tokens_current_history = 0
-# count tokens in history
for i in self.chat_history:
-if "prompt" in i and "response" in i:
-tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
-i["response"]
-)
-if tokens_current_history + tokens_batch < self.token_limit:
-tokens_current_history += tokens_batch
+if "prompt" in i and "response" in i:
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
@@ -89,17 +82,23 @@ class ClassicRAG(BaseRetriever):
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
-llm = LLMCreator.create_llm(
-settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+# llm = LLMCreator.create_llm(
+#     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+# )
+# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
+agent = Agent(
+llm_name=settings.LLM_NAME,
+gpt_model=self.gpt_model,
+api_key=settings.API_KEY,
+user_api_key=self.user_api_key,
)
-completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
+completion = agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
@@ -109,5 +108,5 @@ class ClassicRAG(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
"user_api_key": self.user_api_key,
}

View File

@@ -1,7 +1,6 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
-from application.utils import num_tokens_from_string
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
@@ -89,16 +88,9 @@ class DuckDuckSearch(BaseRetriever):
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
-tokens_current_history = 0
-# count tokens in history
+if len(self.chat_history) > 1:
for i in self.chat_history:
-if "prompt" in i and "response" in i:
-tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
-i["response"]
-)
-if tokens_current_history + tokens_batch < self.token_limit:
-tokens_current_history += tokens_batch
+if "prompt" in i and "response" in i:
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)

application/tools/agent.py (new file, 149 lines)
View File

@@ -0,0 +1,149 @@
import json
from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.tools.tool_manager import ToolManager
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
return tools_by_id
def _prepare_tools(self, tools_dict):
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": {
**action["parameters"],
"properties": {
k: {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
}
for k, v in action["parameters"]["properties"].items()
if v.get("filled_by_llm", False)
},
"required": [
key
for key in action["parameters"]["required"]
if key in action["parameters"]["properties"]
and action["parameters"]["properties"][key].get(
"filled_by_llm", False
)
],
},
},
}
for tool_id, tool in tools_dict.items()
for action in tool["actions"]
if action["active"]
]
def _execute_tool_action(self, tools_dict, call):
call_id = call.id
call_args = json.loads(call.function.arguments)
tool_id = call.function.name.split("_")[-1]
action_name = call.function.name.rsplit("_", 1)[0]
tool_data = tools_dict[tool_id]
action_data = next(
action for action in tool_data["actions"] if action["name"] == action_name
)
for param, details in action_data["parameters"]["properties"].items():
if param not in call_args and "value" in details:
call_args[param] = details["value"]
tm = ToolManager(config={})
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
print(f"Executing tool: {action_name} with args: {call_args}")
return tool.execute_action(action_name, **call_args), call_id
def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
if isinstance(resp, str):
yield resp
return
if resp.message.content:
yield resp.message.content
return
while resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"]
keys_to_remove = {"audio", "function_call", "refusal"}
filtered_data = {
k: v for k, v in message.items() if k not in keys_to_remove
}
messages.append(filtered_data)
tool_calls = resp.message.tool_calls
for call in tool_calls:
try:
tool_response, call_id = self._execute_tool_action(tools_dict, call)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
}
)
except Exception as e:
messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(
model=self.gpt_model, messages=messages, tools=self.tools
)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
elif resp.message.content:
yield resp.message.content
else:
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
)
for line in completion:
yield line
return
def gen(self, messages):
# Generate initial response from the LLM
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:
yield line
else:
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
for line in resp:
yield line
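
ClassicRAG above consumes Agent.gen as a generator, and it can be exercised the same way directly. A sketch, with the model name and question purely illustrative; which tools actually run depends on what is enabled in the user_tools collection:

from application.core.settings import settings
from application.tools.agent import Agent

agent = Agent(
    llm_name=settings.LLM_NAME,
    gpt_model="gpt-4o-mini",
    api_key=settings.API_KEY,
)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is BTC trading at in USD?"},
]
for chunk in agent.gen(messages):
    print(chunk, end="")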

application/tools/base.py (new file, 21 lines)
View File

@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass
@abstractmethod
def get_actions_metadata(self):
"""
Returns a list of JSON objects describing the actions supported by the tool.
"""
pass
@abstractmethod
def get_config_requirements(self):
"""
Returns a dictionary describing the configuration requirements for the tool.
"""
pass

View File

@@ -0,0 +1,77 @@
import requests
from application.tools.base import Tool
class CryptoPriceTool(Tool):
"""
CryptoPrice
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {"cryptoprice_get": self._get_price}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
symbol = "BTC"
currency = "USD"
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
# data will be like {"USD": <price>} if the call is successful
if currency.upper() in data:
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price.",
}
def get_actions_metadata(self):
return [
{
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False,
},
}
]
def get_config_requirements(self):
# No specific configuration needed for this tool as it just queries a public endpoint
return {}
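
Because the tool wraps a public endpoint with no required config, it can be smoke-tested standalone:

tool = CryptoPriceTool(config={})
result = tool.execute_action("cryptoprice_get", symbol="BTC", currency="USD")
# e.g. {"status_code": 200, "price": ..., "message": "Price of BTC in USD retrieved successfully."}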

View File

@@ -0,0 +1,86 @@
import requests
from application.tools.base import Tool
class TelegramTool(Tool):
"""
Telegram Bot
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
Requires a bot token and chat ID for configuration
"""
def __init__(self, config):
self.config = config
self.token = config.get("token", "")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
return [
{
"name": "telegram_send_message",
"description": "Send a notification to Telegram chat",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the notification to",
},
},
"required": ["text"],
"additionalProperties": False,
},
},
{
"name": "telegram_send_image",
"description": "Send an image to the Telegram chat",
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the image to",
},
},
"required": ["image_url"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -0,0 +1,46 @@
import importlib
import inspect
import os
import pkgutil
from application.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
self.tools = {}
self.load_tools()
def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):
metadata = []
for tool in self.tools.values():
metadata.extend(tool.get_actions_metadata())
return metadata
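
A usage sketch: ToolManager discovers every Tool subclass under implementations/ at construction time, keyed by module name (assuming the implementation files are named cryptoprice.py and telegram.py):

tm = ToolManager(config={})
print(sorted(tm.tools))  # e.g. ['cryptoprice', 'telegram']
print(tm.execute_action("cryptoprice", "cryptoprice_get", symbol="ETH", currency="EUR"))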

View File

@@ -1,7 +1,7 @@
import sys
from datetime import datetime
from application.core.mongo_db import MongoDB
-from application.utils import num_tokens_from_string
+from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage):
def gen_token_usage(func):
-def wrapper(self, model, messages, stream, **kwargs):
+def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, **kwargs)
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, tools, **kwargs)
# check if result is a string
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -33,11 +38,11 @@ def gen_token_usage(func):
def stream_token_usage(func):
-def wrapper(self, model, messages, stream, **kwargs):
+def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
batch = []
-result = func(self, model, messages, stream, **kwargs)
+result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r

View File

@@ -15,9 +15,21 @@ def get_encoding():
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
-num_tokens = len(encoding.encode(string))
-return num_tokens
+if isinstance(string, str):
+num_tokens = len(encoding.encode(string))
+return num_tokens
+else:
+return 0
def num_tokens_from_object_or_list(thing):
if isinstance(thing, list):
return sum([num_tokens_from_object_or_list(x) for x in thing])
elif isinstance(thing, dict):
return sum([num_tokens_from_object_or_list(x) for x in thing.values()])
elif isinstance(thing, str):
return num_tokens_from_string(thing)
else:
return 0
def count_tokens_docs(docs):
docs_content = ""
@@ -46,3 +58,40 @@ def check_required_fields(data, required_fields):
def get_hash(data):
return hashlib.md5(data.encode()).hexdigest()
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""
Limits chat history based on token count.
Returns a list of messages that fit within the token limit.
"""
from application.core.settings import settings
max_token_limit = (
max_token_limit
if max_token_limit and
max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.MODEL_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
if not history:
return []
tokens_current_history = 0
trimmed_history = []
for message in reversed(history):
if "prompt" in message and "response" in message:
tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
message["response"]
)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
return trimmed_history
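
Because the loop walks history newest-first and breaks at the first pair that overflows, it is the oldest exchanges that get dropped. A worked sketch with an artificially small budget (each short pair here costs roughly four tokens under the cl100k_base encoding):

history = [
    {"prompt": "old question", "response": "old answer"},
    {"prompt": "recent question", "response": "recent answer"},
]
trimmed = limit_chat_history(history, max_token_limit=5, gpt_model="docsgpt")
# Only the most recent pair fits under the 5-token budget, so:
# trimmed == [{"prompt": "recent question", "response": "recent answer"}]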

View File

@@ -12,10 +12,10 @@ from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.file.bulk import SimpleDirectoryReader
-from application.parser.open_ai_func import call_openai_api
+from application.parser.embedding_pipeline import embed_and_store_documents
from application.parser.remote.remote_creator import RemoteCreator
from application.parser.schema.base import Document
-from application.parser.token_func import group_split
+from application.parser.chunking import Chunker
from application.utils import count_tokens_docs
mongo = MongoDB.get_client()
@@ -126,7 +126,6 @@ def ingest_worker(
limit = None
exclude = True
sample = False
-token_check = True
full_path = os.path.join(directory, user, name_job)
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": name_job})
@@ -153,17 +152,19 @@ def ingest_worker(
exclude_hidden=exclude,
file_metadata=metadata_from_filename,
).load_data()
-raw_docs = group_split(
-documents=raw_docs,
-min_tokens=MIN_TOKENS,
+chunker = Chunker(
+chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
-token_check=token_check,
+min_tokens=MIN_TOKENS,
+duplicate_headers=False
)
+raw_docs = chunker.chunk(documents=raw_docs)
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
id = ObjectId()
-call_openai_api(docs, full_path, id, self)
+embed_and_store_documents(docs, full_path, id, self)
tokens = count_tokens_docs(docs)
self.update_state(state="PROGRESS", meta={"current": 100})
@@ -203,7 +204,6 @@ def remote_worker(
operation_mode="upload",
doc_id=None,
):
-token_check = True
full_path = os.path.join(directory, user, name_job)
if not os.path.exists(full_path):
@@ -217,21 +217,23 @@ def remote_worker(
remote_loader = RemoteCreator.create_loader(loader)
raw_docs = remote_loader.load_data(source_data)
-docs = group_split(
-documents=raw_docs,
-min_tokens=MIN_TOKENS,
+chunker = Chunker(
+chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
-token_check=token_check,
+min_tokens=MIN_TOKENS,
+duplicate_headers=False
)
+docs = chunker.chunk(documents=raw_docs)
tokens = count_tokens_docs(docs)
if operation_mode == "upload":
id = ObjectId()
-call_openai_api(docs, full_path, id, self)
+embed_and_store_documents(docs, full_path, id, self)
elif operation_mode == "sync":
if not doc_id or not ObjectId.is_valid(doc_id):
raise ValueError("doc_id must be provided for sync operation.")
id = ObjectId(doc_id)
-call_openai_api(docs, full_path, id, self)
+embed_and_store_documents(docs, full_path, id, self)
self.update_state(state="PROGRESS", meta={"current": 100})
file_data = {