mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'arc53:main' into basic-ui
This commit is contained in:
4
.github/workflows/pytest.yml
vendored
4
.github/workflows/pytest.yml
vendored
@@ -6,7 +6,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
run: |
|
||||
python -m pytest --cov=application --cov-report=xml
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.11'
|
||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
38
.vscode/launch.json
vendored
38
.vscode/launch.json
vendored
@@ -11,6 +11,44 @@
|
||||
"skipFiles": [
|
||||
"<node_internals>/**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Flask Debugger",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "application/app.py",
|
||||
"PYTHONPATH": "${workspaceFolder}",
|
||||
"FLASK_ENV": "development",
|
||||
"FLASK_DEBUG": "1",
|
||||
"FLASK_RUN_PORT": "7091",
|
||||
"FLASK_RUN_HOST": "0.0.0.0"
|
||||
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--no-debugger"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
},
|
||||
{
|
||||
"name": "Celery Debugger",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}",
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"application.app.celery",
|
||||
"worker",
|
||||
"-l",
|
||||
"INFO",
|
||||
"--pool=solo"
|
||||
],
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -8,14 +8,14 @@ RUN apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
# Install necessary packages and Python
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.11 python3.11-distutils python3.11-venv && \
|
||||
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
RUN if [ -f /usr/bin/python3.11 ]; then \
|
||||
ln -s /usr/bin/python3.11 /usr/bin/python; \
|
||||
RUN if [ -f /usr/bin/python3.12 ]; then \
|
||||
ln -s /usr/bin/python3.12 /usr/bin/python; \
|
||||
else \
|
||||
echo "Python 3.11 not found"; exit 1; \
|
||||
echo "Python 3.12 not found"; exit 1; \
|
||||
fi
|
||||
|
||||
# Download and unzip the model
|
||||
@@ -33,7 +33,7 @@ RUN apt-get remove --purge -y wget unzip && apt-get autoremove -y && rm -rf /var
|
||||
COPY requirements.txt .
|
||||
|
||||
# Setup Python virtual environment
|
||||
RUN python3.11 -m venv /venv
|
||||
RUN python3.12 -m venv /venv
|
||||
|
||||
# Activate virtual environment and install Python packages
|
||||
ENV PATH="/venv/bin:$PATH"
|
||||
@@ -50,8 +50,8 @@ RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
# Install Python
|
||||
apt-get update && apt-get install -y --no-install-recommends python3.11 && \
|
||||
ln -s /usr/bin/python3.11 /usr/bin/python && \
|
||||
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
|
||||
ln -s /usr/bin/python3.12 /usr/bin/python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set working directory
|
||||
|
||||
@@ -18,7 +18,7 @@ from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.utils import check_required_fields, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,7 +37,7 @@ api.add_namespace(answer_ns)
|
||||
gpt_model = ""
|
||||
# to have some kind of default behaviour
|
||||
if settings.LLM_NAME == "openai":
|
||||
gpt_model = "gpt-3.5-turbo"
|
||||
gpt_model = "gpt-4o-mini"
|
||||
elif settings.LLM_NAME == "anthropic":
|
||||
gpt_model = "claude-2"
|
||||
elif settings.LLM_NAME == "groq":
|
||||
@@ -324,8 +324,7 @@ class Stream(Resource):
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = str(data.get("history", []))
|
||||
history = str(json.loads(history))
|
||||
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
|
||||
@@ -456,7 +455,7 @@ class Answer(Resource):
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import math
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, jsonify, make_response, request, redirect
|
||||
from flask_restx import inputs, fields, Namespace, Resource
|
||||
from flask import Blueprint, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, inputs, Namespace, Resource
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.api.user.tasks import ingest, ingest_remote
|
||||
@@ -16,9 +16,10 @@ from application.api.user.tasks import ingest, ingest_remote
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.extensions import api
|
||||
from application.tools.tool_manager import ToolManager
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
@@ -30,6 +31,7 @@ api_key_collection = db["api_keys"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
|
||||
user = Blueprint("user", __name__)
|
||||
user_ns = Namespace("user", description="User related operations", path="/")
|
||||
@@ -39,6 +41,9 @@ current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
def generate_minute_range(start_date, end_date):
|
||||
return {
|
||||
@@ -176,10 +181,12 @@ class SubmitFeedback(Resource):
|
||||
"FeedbackModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The user question"
|
||||
required=False, description="The user question"
|
||||
),
|
||||
"answer": fields.String(required=True, description="The AI answer"),
|
||||
"answer": fields.String(required=False, description="The AI answer"),
|
||||
"feedback": fields.String(required=True, description="User feedback"),
|
||||
"question_index":fields.Integer(required=True, description="The question number in that particular conversation"),
|
||||
"conversation_id":fields.String(required=True, description="id of the particular conversation"),
|
||||
"api_key": fields.String(description="Optional API key"),
|
||||
},
|
||||
)
|
||||
@@ -189,23 +196,21 @@ class SubmitFeedback(Resource):
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question", "answer", "feedback"]
|
||||
required_fields = [ "feedback","conversation_id","question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
new_doc = {
|
||||
"question": data["question"],
|
||||
"answer": data["answer"],
|
||||
"feedback": data["feedback"],
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
|
||||
if "api_key" in data:
|
||||
new_doc["api_key"] = data["api_key"]
|
||||
|
||||
try:
|
||||
feedback_collection.insert_one(new_doc)
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["conversation_id"]), f"queries.{data['question_index']}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data["feedback"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
@@ -1802,3 +1807,295 @@ class TextToSpeech(Resource):
|
||||
)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@user_ns.route("/api/available_tools")
|
||||
class AvailableTools(Resource):
|
||||
@api.doc(description="Get available tools for a user")
|
||||
def get(self):
|
||||
try:
|
||||
tools_metadata = []
|
||||
for tool_name, tool_instance in tool_manager.tools.items():
|
||||
doc = tool_instance.__doc__.strip()
|
||||
lines = doc.split("\n", 1)
|
||||
name = lines[0].strip()
|
||||
description = lines[1].strip() if len(lines) > 1 else ""
|
||||
tools_metadata.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"displayName": name,
|
||||
"description": description,
|
||||
"configRequirements": tool_instance.get_config_requirements(),
|
||||
"actions": tool_instance.get_actions_metadata(),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/get_tools")
|
||||
class GetTools(Resource):
|
||||
@api.doc(description="Get tools created by a user")
|
||||
def get(self):
|
||||
try:
|
||||
user = "local"
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
user_tools = []
|
||||
for tool in tools:
|
||||
tool["id"] = str(tool["_id"])
|
||||
tool.pop("_id")
|
||||
user_tools.append(tool)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/create_tool")
|
||||
class CreateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateToolModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="Name of the tool"),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the tool"
|
||||
),
|
||||
"description": fields.String(
|
||||
required=True, description="Tool description"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
"actions": fields.List(
|
||||
fields.Raw,
|
||||
required=True,
|
||||
description="Actions the tool can perform",
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create a new tool")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = [
|
||||
"name",
|
||||
"displayName",
|
||||
"description",
|
||||
"actions",
|
||||
"config",
|
||||
"status",
|
||||
]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
user = "local"
|
||||
transformed_actions = []
|
||||
for action in data["actions"]:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
try:
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"displayName": data["displayName"],
|
||||
"description": data["description"],
|
||||
"actions": transformed_actions,
|
||||
"config": data["config"],
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/update_tool")
|
||||
class UpdateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"name": fields.String(description="Name of the tool"),
|
||||
"displayName": fields.String(description="Display name for the tool"),
|
||||
"description": fields.String(description="Tool description"),
|
||||
"config": fields.Raw(description="Configuration of the tool"),
|
||||
"actions": fields.List(
|
||||
fields.Raw, description="Actions the tool can perform"
|
||||
),
|
||||
"status": fields.Boolean(description="Status of the tool"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update a tool by ID")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
update_data = {}
|
||||
if "name" in data:
|
||||
update_data["name"] = data["name"]
|
||||
if "displayName" in data:
|
||||
update_data["displayName"] = data["displayName"]
|
||||
if "description" in data:
|
||||
update_data["description"] = data["description"]
|
||||
if "actions" in data:
|
||||
update_data["actions"] = data["actions"]
|
||||
if "config" in data:
|
||||
update_data["config"] = data["config"]
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": "local"},
|
||||
{"$set": update_data},
|
||||
)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/update_tool_config")
|
||||
class UpdateToolConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolConfigModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the configuration of a tool")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"$set": {"config": data["config"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/update_tool_actions")
|
||||
class UpdateToolActions(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolActionsModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"actions": fields.List(
|
||||
fields.Raw,
|
||||
required=True,
|
||||
description="Actions the tool can perform",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the actions of a tool")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "actions"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"$set": {"actions": data["actions"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/update_tool_status")
|
||||
class UpdateToolStatus(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolStatusModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the status of a tool")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "status"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"$set": {"status": data["status"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/delete_tool")
|
||||
class DeleteTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DeleteToolModel",
|
||||
{"id": fields.String(required=True, description="Tool ID")},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Delete a tool by ID")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])})
|
||||
if result.deleted_count == 0:
|
||||
return {"success": False, "message": "Tool not found"}, 404
|
||||
except Exception as err:
|
||||
return {"success": False, "error": str(err)}, 400
|
||||
|
||||
return {"success": True}, 200
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import redis
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
import redis
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.utils import get_hash
|
||||
|
||||
@@ -11,41 +13,47 @@ logger = logging.getLogger(__name__)
|
||||
_redis_instance = None
|
||||
_instance_lock = Lock()
|
||||
|
||||
|
||||
def get_redis_instance():
|
||||
global _redis_instance
|
||||
if _redis_instance is None:
|
||||
with _instance_lock:
|
||||
if _redis_instance is None:
|
||||
try:
|
||||
_redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
|
||||
_redis_instance = redis.Redis.from_url(
|
||||
settings.CACHE_REDIS_URL, socket_connect_timeout=2
|
||||
)
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Redis connection error: {e}")
|
||||
_redis_instance = None
|
||||
return _redis_instance
|
||||
|
||||
def gen_cache_key(*messages, model="docgpt"):
|
||||
|
||||
def gen_cache_key(messages, model="docgpt", tools=None):
|
||||
if not all(isinstance(msg, dict) for msg in messages):
|
||||
raise ValueError("All messages must be dictionaries.")
|
||||
messages_str = json.dumps(list(messages), sort_keys=True)
|
||||
combined = f"{model}_{messages_str}"
|
||||
messages_str = json.dumps(messages)
|
||||
tools_str = json.dumps(tools) if tools else ""
|
||||
combined = f"{model}_{messages_str}_{tools_str}"
|
||||
cache_key = get_hash(combined)
|
||||
return cache_key
|
||||
|
||||
|
||||
def gen_cache(func):
|
||||
def wrapper(self, model, messages, *args, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
|
||||
try:
|
||||
cache_key = gen_cache_key(*messages)
|
||||
cache_key = gen_cache_key(messages, model, tools)
|
||||
redis_client = get_redis_instance()
|
||||
if redis_client:
|
||||
try:
|
||||
cached_response = redis_client.get(cache_key)
|
||||
if cached_response:
|
||||
return cached_response.decode('utf-8')
|
||||
return cached_response.decode("utf-8")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Redis connection error: {e}")
|
||||
|
||||
result = func(self, model, messages, *args, **kwargs)
|
||||
if redis_client:
|
||||
result = func(self, model, messages, stream, tools, *args, **kwargs)
|
||||
if redis_client and isinstance(result, str):
|
||||
try:
|
||||
redis_client.set(cache_key, result, ex=1800)
|
||||
except redis.ConnectionError as e:
|
||||
@@ -55,20 +63,22 @@ def gen_cache(func):
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
return "Error: No user message found in the conversation to generate a cache key."
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def stream_cache(func):
|
||||
def wrapper(self, model, messages, stream, *args, **kwargs):
|
||||
cache_key = gen_cache_key(*messages)
|
||||
cache_key = gen_cache_key(messages)
|
||||
logger.info(f"Stream cache key: {cache_key}")
|
||||
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
if redis_client:
|
||||
try:
|
||||
cached_response = redis_client.get(cache_key)
|
||||
if cached_response:
|
||||
logger.info(f"Cache hit for stream key: {cache_key}")
|
||||
cached_response = json.loads(cached_response.decode('utf-8'))
|
||||
cached_response = json.loads(cached_response.decode("utf-8"))
|
||||
for chunk in cached_response:
|
||||
yield chunk
|
||||
time.sleep(0.03)
|
||||
@@ -78,16 +88,16 @@ def stream_cache(func):
|
||||
|
||||
result = func(self, model, messages, stream, *args, **kwargs)
|
||||
stream_cache_data = []
|
||||
|
||||
|
||||
for chunk in result:
|
||||
stream_cache_data.append(chunk)
|
||||
yield chunk
|
||||
|
||||
|
||||
if redis_client:
|
||||
try:
|
||||
redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
|
||||
logger.info(f"Stream cache saved for key: {cache_key}")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Redis connection error: {e}")
|
||||
|
||||
return wrapper
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -16,7 +16,7 @@ class Settings(BaseSettings):
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
|
||||
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
PARSE_PDF_AS_IMAGE: bool = False
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
||||
|
||||
@@ -17,7 +17,7 @@ class AnthropicLLM(BaseLLM):
|
||||
self.AI_PROMPT = AI_PROMPT
|
||||
|
||||
def _raw_gen(
|
||||
self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
|
||||
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
|
||||
):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
@@ -34,7 +34,7 @@ class AnthropicLLM(BaseLLM):
|
||||
return completion.completion
|
||||
|
||||
def _raw_gen_stream(
|
||||
self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
|
||||
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
|
||||
):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
|
||||
@@ -13,12 +13,12 @@ class BaseLLM(ABC):
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen(self, model, messages, stream, *args, **kwargs):
|
||||
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def gen(self, model, messages, stream=False, *args, **kwargs):
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
||||
@@ -26,4 +26,10 @@ class BaseLLM(ABC):
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
|
||||
def supports_tools(self):
|
||||
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
|
||||
|
||||
def _supports_tools(self):
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
@@ -1,45 +1,32 @@
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class GroqLLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
model,
|
||||
messages,
|
||||
stream=False,
|
||||
**kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
baseself,
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
**kwargs
|
||||
):
|
||||
self, baseself, model, messages, stream=True, tools=None, **kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
for line in response:
|
||||
# import sys
|
||||
# print(line.choices[0].delta.content, file=sys.stderr)
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
@@ -25,14 +25,20 @@ class OpenAILLM(BaseLLM):
|
||||
model,
|
||||
messages,
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
@@ -40,6 +46,7 @@ class OpenAILLM(BaseLLM):
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
@@ -52,6 +59,9 @@ class OpenAILLM(BaseLLM):
|
||||
# print(line.choices[0].delta.content, file=sys.stderr)
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
|
||||
@@ -76,7 +76,7 @@ class SagemakerAPILLM(BaseLLM):
|
||||
self.endpoint = settings.SAGEMAKER_ENDPOINT
|
||||
self.runtime = runtime
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
@@ -105,7 +105,7 @@ class SagemakerAPILLM(BaseLLM):
|
||||
print(result[0]["generated_text"], file=sys.stderr)
|
||||
return result[0]["generated_text"][len(prompt) :]
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
||||
118
application/parser/chunking.py
Normal file
118
application/parser/chunking.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
import logging
|
||||
from application.parser.schema.base import Document
|
||||
from application.utils import get_encoding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Chunker:
|
||||
def __init__(
|
||||
self,
|
||||
chunking_strategy: str = "classic_chunk",
|
||||
max_tokens: int = 2000,
|
||||
min_tokens: int = 150,
|
||||
duplicate_headers: bool = False,
|
||||
):
|
||||
if chunking_strategy not in ["classic_chunk"]:
|
||||
raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")
|
||||
self.chunking_strategy = chunking_strategy
|
||||
self.max_tokens = max_tokens
|
||||
self.min_tokens = min_tokens
|
||||
self.duplicate_headers = duplicate_headers
|
||||
self.encoding = get_encoding()
|
||||
|
||||
def separate_header_and_body(self, text: str) -> Tuple[str, str]:
|
||||
header_pattern = r"^(.*?\n){3}"
|
||||
match = re.match(header_pattern, text)
|
||||
if match:
|
||||
header = match.group(0)
|
||||
body = text[len(header):]
|
||||
else:
|
||||
header, body = "", text # No header, treat entire text as body
|
||||
return header, body
|
||||
|
||||
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
|
||||
combined_text = doc.text + " " + next_doc.text
|
||||
combined_token_count = len(self.encoding.encode(combined_text))
|
||||
new_doc = Document(
|
||||
text=combined_text,
|
||||
doc_id=doc.doc_id,
|
||||
embedding=doc.embedding,
|
||||
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
|
||||
)
|
||||
return new_doc
|
||||
|
||||
def split_document(self, doc: Document) -> List[Document]:
|
||||
split_docs = []
|
||||
header, body = self.separate_header_and_body(doc.text)
|
||||
header_tokens = self.encoding.encode(header) if header else []
|
||||
body_tokens = self.encoding.encode(body)
|
||||
|
||||
current_position = 0
|
||||
part_index = 0
|
||||
while current_position < len(body_tokens):
|
||||
end_position = current_position + self.max_tokens - len(header_tokens)
|
||||
chunk_tokens = (header_tokens + body_tokens[current_position:end_position]
|
||||
if self.duplicate_headers or part_index == 0 else body_tokens[current_position:end_position])
|
||||
chunk_text = self.encoding.decode(chunk_tokens)
|
||||
new_doc = Document(
|
||||
text=chunk_text,
|
||||
doc_id=f"{doc.doc_id}-{part_index}",
|
||||
embedding=doc.embedding,
|
||||
extra_info={**(doc.extra_info or {}), "token_count": len(chunk_tokens)}
|
||||
)
|
||||
split_docs.append(new_doc)
|
||||
current_position = end_position
|
||||
part_index += 1
|
||||
header_tokens = []
|
||||
return split_docs
|
||||
|
||||
def classic_chunk(self, documents: List[Document]) -> List[Document]:
|
||||
processed_docs = []
|
||||
i = 0
|
||||
while i < len(documents):
|
||||
doc = documents[i]
|
||||
tokens = self.encoding.encode(doc.text)
|
||||
token_count = len(tokens)
|
||||
|
||||
if self.min_tokens <= token_count <= self.max_tokens:
|
||||
doc.extra_info = doc.extra_info or {}
|
||||
doc.extra_info["token_count"] = token_count
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
elif token_count < self.min_tokens:
|
||||
if i + 1 < len(documents):
|
||||
next_doc = documents[i + 1]
|
||||
next_tokens = self.encoding.encode(next_doc.text)
|
||||
if token_count + len(next_tokens) <= self.max_tokens:
|
||||
# Combine small documents
|
||||
combined_doc = self.combine_documents(doc, next_doc)
|
||||
processed_docs.append(combined_doc)
|
||||
i += 2
|
||||
else:
|
||||
# Keep the small document as is if adding next_doc would exceed max_tokens
|
||||
doc.extra_info = doc.extra_info or {}
|
||||
doc.extra_info["token_count"] = token_count
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
else:
|
||||
# No next document to combine with; add the small document as is
|
||||
doc.extra_info = doc.extra_info or {}
|
||||
doc.extra_info["token_count"] = token_count
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
else:
|
||||
# Split large documents
|
||||
processed_docs.extend(self.split_document(doc))
|
||||
i += 1
|
||||
return processed_docs
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
documents: List[Document]
|
||||
) -> List[Document]:
|
||||
if self.chunking_strategy == "classic_chunk":
|
||||
return self.classic_chunk(documents)
|
||||
else:
|
||||
raise ValueError("Unsupported chunking strategy")
|
||||
86
application/parser/embedding_pipeline.py
Executable file
86
application/parser/embedding_pipeline.py
Executable file
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import logging
|
||||
from retry import retry
|
||||
from tqdm import tqdm
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
@retry(tries=10, delay=60)
|
||||
def add_text_to_store_with_retry(store, doc, source_id):
|
||||
"""
|
||||
Add a document's text and metadata to the vector store with retry logic.
|
||||
Args:
|
||||
store: The vector store object.
|
||||
doc: The document to be added.
|
||||
source_id: Unique identifier for the source.
|
||||
"""
|
||||
try:
|
||||
doc.metadata["source_id"] = str(source_id)
|
||||
store.add_texts([doc.page_content], metadatas=[doc.metadata])
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to add document with retry: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def embed_and_store_documents(docs, folder_name, source_id, task_status):
|
||||
"""
|
||||
Embeds documents and stores them in a vector store.
|
||||
|
||||
Args:
|
||||
docs (list): List of documents to be embedded and stored.
|
||||
folder_name (str): Directory to save the vector store.
|
||||
source_id (str): Unique identifier for the source.
|
||||
task_status: Task state manager for progress updates.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Ensure the folder exists
|
||||
if not os.path.exists(folder_name):
|
||||
os.makedirs(folder_name)
|
||||
|
||||
# Initialize vector store
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
docs_init = [docs.pop(0)]
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=docs_init,
|
||||
source_id=folder_name,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
else:
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
store.delete_index()
|
||||
|
||||
total_docs = len(docs)
|
||||
|
||||
# Process and embed documents
|
||||
for idx, doc in tqdm(
|
||||
enumerate(docs),
|
||||
desc="Embedding 🦖",
|
||||
unit="docs",
|
||||
total=total_docs,
|
||||
bar_format="{l_bar}{bar}| Time Left: {remaining}",
|
||||
):
|
||||
try:
|
||||
# Update task status for progress tracking
|
||||
progress = int(((idx + 1) / total_docs) * 100)
|
||||
task_status.update_state(state="PROGRESS", meta={"current": progress})
|
||||
|
||||
# Add document to vector store
|
||||
add_text_to_store_with_retry(store, doc, source_id)
|
||||
except Exception as e:
|
||||
logging.error(f"Error embedding document {idx}: {e}")
|
||||
logging.info(f"Saving progress at document {idx} out of {total_docs}")
|
||||
store.save_local(folder_name)
|
||||
break
|
||||
|
||||
# Save the vector store
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
store.save_local(folder_name)
|
||||
logging.info("Vector store saved successfully.")
|
||||
@@ -1,75 +0,0 @@
|
||||
import os
|
||||
|
||||
from retry import retry
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
# from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
# from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
||||
# from langchain_community.embeddings import CohereEmbeddings
|
||||
|
||||
|
||||
@retry(tries=10, delay=60)
|
||||
def store_add_texts_with_retry(store, i, id):
|
||||
# add source_id to the metadata
|
||||
i.metadata["source_id"] = str(id)
|
||||
store.add_texts([i.page_content], metadatas=[i.metadata])
|
||||
# store_pine.add_texts([i.page_content], metadatas=[i.metadata])
|
||||
|
||||
|
||||
def call_openai_api(docs, folder_name, id, task_status):
|
||||
# Function to create a vector store from the documents and save it to disk
|
||||
|
||||
if not os.path.exists(f"{folder_name}"):
|
||||
os.makedirs(f"{folder_name}")
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
c1 = 0
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
docs_init = [docs[0]]
|
||||
docs.pop(0)
|
||||
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=docs_init,
|
||||
source_id=f"{folder_name}",
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
else:
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=str(id),
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
store.delete_index()
|
||||
# Uncomment for MPNet embeddings
|
||||
# model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
# hf = HuggingFaceEmbeddings(model_name=model_name)
|
||||
# store = FAISS.from_documents(docs_test, hf)
|
||||
s1 = len(docs)
|
||||
for i in tqdm(
|
||||
docs,
|
||||
desc="Embedding 🦖",
|
||||
unit="docs",
|
||||
total=len(docs),
|
||||
bar_format="{l_bar}{bar}| Time Left: {remaining}",
|
||||
):
|
||||
try:
|
||||
task_status.update_state(
|
||||
state="PROGRESS", meta={"current": int((c1 / s1) * 100)}
|
||||
)
|
||||
store_add_texts_with_retry(store, i, id)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Error on ", i)
|
||||
print("Saving progress")
|
||||
print(f"stopped at {c1} out of {len(docs)}")
|
||||
store.save_local(f"{folder_name}")
|
||||
break
|
||||
c1 += 1
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
store.save_local(f"{folder_name}")
|
||||
@@ -1,79 +0,0 @@
|
||||
import re
|
||||
from math import ceil
|
||||
from typing import List
|
||||
|
||||
import tiktoken
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
def separate_header_and_body(text):
|
||||
header_pattern = r"^(.*?\n){3}"
|
||||
match = re.match(header_pattern, text)
|
||||
header = match.group(0)
|
||||
body = text[len(header):]
|
||||
return header, body
|
||||
|
||||
|
||||
def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]:
|
||||
docs = []
|
||||
current_group = None
|
||||
|
||||
for doc in documents:
|
||||
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
|
||||
|
||||
# Check if current group is empty or if the document can be added based on token count and matching metadata
|
||||
if (current_group is None or
|
||||
(len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
|
||||
doc_len < min_tokens and
|
||||
current_group.extra_info == doc.extra_info)):
|
||||
if current_group is None:
|
||||
current_group = doc # Use the document directly to retain its metadata
|
||||
else:
|
||||
current_group.text += " " + doc.text # Append text to the current group
|
||||
else:
|
||||
docs.append(current_group)
|
||||
current_group = doc # Start a new group with the current document
|
||||
|
||||
if current_group is not None:
|
||||
docs.append(current_group)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def split_documents(documents: List[Document], max_tokens: int) -> List[Document]:
|
||||
docs = []
|
||||
for doc in documents:
|
||||
token_length = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
|
||||
if token_length <= max_tokens:
|
||||
docs.append(doc)
|
||||
else:
|
||||
header, body = separate_header_and_body(doc.text)
|
||||
if len(tiktoken.get_encoding("cl100k_base").encode(header)) > max_tokens:
|
||||
body = doc.text
|
||||
header = ""
|
||||
num_body_parts = ceil(token_length / max_tokens)
|
||||
part_length = ceil(len(body) / num_body_parts)
|
||||
body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)]
|
||||
for i, body_part in enumerate(body_parts):
|
||||
new_doc = Document(text=header + body_part.strip(),
|
||||
doc_id=f"{doc.doc_id}-{i}",
|
||||
embedding=doc.embedding,
|
||||
extra_info=doc.extra_info)
|
||||
docs.append(new_doc)
|
||||
return docs
|
||||
|
||||
|
||||
def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150, token_check: bool = True):
|
||||
if not token_check:
|
||||
return documents
|
||||
print("Grouping small documents")
|
||||
try:
|
||||
documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens)
|
||||
except Exception:
|
||||
print("Grouping failed, try running without token_check")
|
||||
print("Separating large documents")
|
||||
try:
|
||||
documents = split_documents(documents=documents, max_tokens=max_tokens)
|
||||
except Exception:
|
||||
print("Grouping failed, try running without token_check")
|
||||
return documents
|
||||
@@ -1,24 +1,24 @@
|
||||
anthropic==0.40.0
|
||||
boto3==1.34.153
|
||||
beautifulsoup4==4.12.3
|
||||
celery==5.3.6
|
||||
celery==5.4.0
|
||||
dataclasses-json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==6.3.0
|
||||
ebooklib==0.18
|
||||
elastic-transport==8.15.0
|
||||
elasticsearch==8.15.1
|
||||
elastic-transport==8.15.1
|
||||
elasticsearch==8.17.0
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
Flask==3.0.3
|
||||
faiss-cpu==1.8.0.post1
|
||||
faiss-cpu==1.9.0.post1
|
||||
flask-restx==1.3.0
|
||||
gTTS==2.3.2
|
||||
gunicorn==23.0.0
|
||||
html2text==2024.2.26
|
||||
javalang==0.13.0
|
||||
jinja2==3.1.4
|
||||
jinja2==3.1.5
|
||||
jiter==0.5.0
|
||||
jmespath==1.0.1
|
||||
joblib==1.4.2
|
||||
@@ -28,22 +28,22 @@ jsonschema==4.23.0
|
||||
jsonschema-spec==0.2.4
|
||||
jsonschema-specifications==2023.7.1
|
||||
kombu==5.4.2
|
||||
langchain==0.3.11
|
||||
langchain-community==0.3.11
|
||||
langchain-core==0.3.25
|
||||
langchain-openai==0.2.0
|
||||
langchain-text-splitters==0.3.0
|
||||
langsmith==0.2.3
|
||||
langchain==0.3.13
|
||||
langchain-community==0.3.13
|
||||
langchain-core==0.3.28
|
||||
langchain-openai==0.2.14
|
||||
langchain-text-splitters==0.3.4
|
||||
langsmith==0.2.6
|
||||
lazy-object-proxy==1.10.0
|
||||
lxml==5.3.0
|
||||
markupsafe==2.1.5
|
||||
marshmallow==3.22.0
|
||||
marshmallow==3.23.2
|
||||
mpmath==1.3.0
|
||||
multidict==6.1.0
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.3
|
||||
numpy==1.26.4
|
||||
openai==1.55.3
|
||||
numpy==2.2.1
|
||||
openai==1.58.1
|
||||
openapi-schema-validator==0.6.2
|
||||
openapi-spec-validator==0.6.0
|
||||
openapi3-parser==1.1.18
|
||||
@@ -68,13 +68,13 @@ python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
python-pptx==1.0.2
|
||||
qdrant-client==1.11.0
|
||||
redis==5.0.1
|
||||
redis==5.2.1
|
||||
referencing==0.30.2
|
||||
regex==2024.9.11
|
||||
requests==2.32.3
|
||||
retry==0.9.2
|
||||
sentence-transformers==3.0.1
|
||||
tiktoken==0.7.0
|
||||
sentence-transformers==3.3.1
|
||||
tiktoken==0.8.0
|
||||
tokenizers==0.21.0
|
||||
torch==2.4.1
|
||||
tqdm==4.66.5
|
||||
@@ -86,4 +86,4 @@ urllib3==2.2.3
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
werkzeug==3.1.3
|
||||
yarl==1.11.1
|
||||
yarl==1.18.3
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import num_tokens_from_string
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
|
||||
@@ -73,15 +72,8 @@ class BraveRetSearch(BaseRetriever):
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.tools.agent import Agent
|
||||
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class ClassicRAG(BaseRetriever):
|
||||
@@ -20,7 +20,7 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=None,
|
||||
):
|
||||
self.question = question
|
||||
self.vectorstore = source['active_docs'] if 'active_docs' in source else None
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
@@ -73,15 +73,8 @@ class ClassicRAG(BaseRetriever):
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
@@ -89,17 +82,23 @@ class ClassicRAG(BaseRetriever):
|
||||
{"role": "system", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
# llm = LLMCreator.create_llm(
|
||||
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
# )
|
||||
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
agent = Agent(
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
)
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
completion = agent.gen(messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
return self._get_data()
|
||||
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"question": self.question,
|
||||
@@ -109,5 +108,5 @@ class ClassicRAG(BaseRetriever):
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import num_tokens_from_string
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
@@ -89,16 +88,9 @@ class DuckDuckSearch(BaseRetriever):
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
if len(self.chat_history) > 1:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
|
||||
149
application/tools/agent.py
Normal file
149
application/tools/agent.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
# Initialize the LLM with the provided parameters
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
)
|
||||
self.gpt_model = gpt_model
|
||||
# Static tool configuration (to be replaced later)
|
||||
self.tools = []
|
||||
self.tool_config = {}
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
return tools_by_id
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": {
|
||||
**action["parameters"],
|
||||
"properties": {
|
||||
k: {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
for k, v in action["parameters"]["properties"].items()
|
||||
if v.get("filled_by_llm", False)
|
||||
},
|
||||
"required": [
|
||||
key
|
||||
for key in action["parameters"]["required"]
|
||||
if key in action["parameters"]["properties"]
|
||||
and action["parameters"]["properties"][key].get(
|
||||
"filled_by_llm", False
|
||||
)
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
for action in tool["actions"]
|
||||
if action["active"]
|
||||
]
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
call_id = call.id
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
action for action in tool_data["actions"] if action["name"] == action_name
|
||||
)
|
||||
|
||||
for param, details in action_data["parameters"]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
call_args[param] = details["value"]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args), call_id
|
||||
|
||||
def _simple_tool_agent(self, messages):
|
||||
tools_dict = self._get_user_tools()
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
return
|
||||
if resp.message.content:
|
||||
yield resp.message.content
|
||||
return
|
||||
|
||||
while resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
tool_response, call_id = self._execute_tool_action(tools_dict, call)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call.id,
|
||||
}
|
||||
)
|
||||
# Generate a new response from the LLM after processing tools
|
||||
resp = self.llm.gen(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
|
||||
# If no tool calls are needed, generate the final response
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
elif resp.message.content:
|
||||
yield resp.message.content
|
||||
else:
|
||||
completion = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
for line in completion:
|
||||
yield line
|
||||
|
||||
return
|
||||
|
||||
def gen(self, messages):
|
||||
# Generate initial response from the LLM
|
||||
if self.llm.supports_tools():
|
||||
resp = self._simple_tool_agent(messages)
|
||||
for line in resp:
|
||||
yield line
|
||||
else:
|
||||
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
|
||||
for line in resp:
|
||||
yield line
|
||||
21
application/tools/base.py
Normal file
21
application/tools/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
Returns a list of JSON objects describing the actions supported by the tool.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Returns a dictionary describing the configuration requirements for the tool.
|
||||
"""
|
||||
pass
|
||||
77
application/tools/implementations/cryptoprice.py
Normal file
77
application/tools/implementations/cryptoprice.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class CryptoPriceTool(Tool):
|
||||
"""
|
||||
CryptoPrice
|
||||
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {"cryptoprice_get": self._get_price}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _get_price(self, symbol, currency):
|
||||
"""
|
||||
Fetches the current price of a given cryptocurrency symbol in the specified currency.
|
||||
Example:
|
||||
symbol = "BTC"
|
||||
currency = "USD"
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
# data will be like {"USD": <price>} if the call is successful
|
||||
if currency.upper() in data:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"price": data[currency.upper()],
|
||||
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": "Failed to retrieve price.",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)",
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)",
|
||||
},
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
# No specific configuration needed for this tool as it just queries a public endpoint
|
||||
return {}
|
||||
86
application/tools/implementations/telegram.py
Normal file
86
application/tools/implementations/telegram.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class TelegramTool(Tool):
|
||||
"""
|
||||
Telegram Bot
|
||||
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
|
||||
Requires a bot token and chat ID for configuration
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.token = config.get("token", "")
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"telegram_send_message": self._send_message,
|
||||
"telegram_send_image": self._send_image,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _send_message(self, text, chat_id):
|
||||
print(f"Sending message: {text}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
print(f"Sending image: {image_url}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "telegram_send_message",
|
||||
"description": "Send a notification to Telegram chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Chat ID to send the notification to",
|
||||
},
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "telegram_send_image",
|
||||
"description": "Send an image to the Telegram chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the image to send",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Chat ID to send the image to",
|
||||
},
|
||||
},
|
||||
"required": ["image_url"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
||||
}
|
||||
46
application/tools/tool_manager.py
Normal file
46
application/tools/tool_manager.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import pkgutil
|
||||
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.tools = {}
|
||||
self.load_tools()
|
||||
|
||||
def load_tools(self):
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
|
||||
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
|
||||
if name == "base" or name.startswith("__"):
|
||||
continue
|
||||
module = importlib.import_module(
|
||||
f"application.tools.implementations.{name}"
|
||||
)
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(
|
||||
f"application.tools.implementations.{tool_name}"
|
||||
)
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
metadata = []
|
||||
for tool in self.tools.values():
|
||||
metadata.extend(tool.get_actions_metadata())
|
||||
return metadata
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
@@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage):
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
result = func(self, model, messages, stream, **kwargs)
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
if message["content"]:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
# check if result is a string
|
||||
if isinstance(result, str):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
else:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
return result
|
||||
|
||||
@@ -33,11 +38,11 @@ def gen_token_usage(func):
|
||||
|
||||
|
||||
def stream_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
batch = []
|
||||
result = func(self, model, messages, stream, **kwargs)
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
|
||||
@@ -15,9 +15,21 @@ def get_encoding():
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
encoding = get_encoding()
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
if isinstance(string, str):
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
else:
|
||||
return 0
|
||||
|
||||
def num_tokens_from_object_or_list(thing):
|
||||
if isinstance(thing, list):
|
||||
return sum([num_tokens_from_object_or_list(x) for x in thing])
|
||||
elif isinstance(thing, dict):
|
||||
return sum([num_tokens_from_object_or_list(x) for x in thing.values()])
|
||||
elif isinstance(thing, str):
|
||||
return num_tokens_from_string(thing)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def count_tokens_docs(docs):
|
||||
docs_content = ""
|
||||
@@ -46,3 +58,40 @@ def check_required_fields(data, required_fields):
|
||||
def get_hash(data):
|
||||
return hashlib.md5(data.encode()).hexdigest()
|
||||
|
||||
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
"""
|
||||
Limits chat history based on token count.
|
||||
Returns a list of messages that fit within the token limit.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
max_token_limit = (
|
||||
max_token_limit
|
||||
if max_token_limit and
|
||||
max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
|
||||
gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
else settings.MODEL_TOKEN_LIMITS.get(
|
||||
gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if not history:
|
||||
return []
|
||||
|
||||
tokens_current_history = 0
|
||||
trimmed_history = []
|
||||
|
||||
for message in reversed(history):
|
||||
if "prompt" in message and "response" in message:
|
||||
tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
|
||||
message["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < max_token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
trimmed_history.insert(0, message)
|
||||
else:
|
||||
break
|
||||
|
||||
return trimmed_history
|
||||
|
||||
@@ -12,10 +12,10 @@ from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.parser.file.bulk import SimpleDirectoryReader
|
||||
from application.parser.open_ai_func import call_openai_api
|
||||
from application.parser.embedding_pipeline import embed_and_store_documents
|
||||
from application.parser.remote.remote_creator import RemoteCreator
|
||||
from application.parser.schema.base import Document
|
||||
from application.parser.token_func import group_split
|
||||
from application.parser.chunking import Chunker
|
||||
from application.utils import count_tokens_docs
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
@@ -126,7 +126,6 @@ def ingest_worker(
|
||||
limit = None
|
||||
exclude = True
|
||||
sample = False
|
||||
token_check = True
|
||||
full_path = os.path.join(directory, user, name_job)
|
||||
|
||||
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": name_job})
|
||||
@@ -153,17 +152,19 @@ def ingest_worker(
|
||||
exclude_hidden=exclude,
|
||||
file_metadata=metadata_from_filename,
|
||||
).load_data()
|
||||
raw_docs = group_split(
|
||||
documents=raw_docs,
|
||||
min_tokens=MIN_TOKENS,
|
||||
|
||||
chunker = Chunker(
|
||||
chunking_strategy="classic_chunk",
|
||||
max_tokens=MAX_TOKENS,
|
||||
token_check=token_check,
|
||||
min_tokens=MIN_TOKENS,
|
||||
duplicate_headers=False
|
||||
)
|
||||
raw_docs = chunker.chunk(documents=raw_docs)
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
id = ObjectId()
|
||||
|
||||
call_openai_api(docs, full_path, id, self)
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
tokens = count_tokens_docs(docs)
|
||||
self.update_state(state="PROGRESS", meta={"current": 100})
|
||||
|
||||
@@ -203,7 +204,6 @@ def remote_worker(
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
):
|
||||
token_check = True
|
||||
full_path = os.path.join(directory, user, name_job)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
@@ -217,21 +217,23 @@ def remote_worker(
|
||||
remote_loader = RemoteCreator.create_loader(loader)
|
||||
raw_docs = remote_loader.load_data(source_data)
|
||||
|
||||
docs = group_split(
|
||||
documents=raw_docs,
|
||||
min_tokens=MIN_TOKENS,
|
||||
chunker = Chunker(
|
||||
chunking_strategy="classic_chunk",
|
||||
max_tokens=MAX_TOKENS,
|
||||
token_check=token_check,
|
||||
min_tokens=MIN_TOKENS,
|
||||
duplicate_headers=False
|
||||
)
|
||||
docs = chunker.chunk(documents=raw_docs)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
if operation_mode == "upload":
|
||||
id = ObjectId()
|
||||
call_openai_api(docs, full_path, id, self)
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = ObjectId(doc_id)
|
||||
call_openai_api(docs, full_path, id, self)
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
self.update_state(state="PROGRESS", meta={"current": 100})
|
||||
|
||||
file_data = {
|
||||
|
||||
88
docs/package-lock.json
generated
88
docs/package-lock.json
generated
@@ -8,7 +8,7 @@
|
||||
"dependencies": {
|
||||
"@vercel/analytics": "^1.1.1",
|
||||
"docsgpt-react": "^0.4.8",
|
||||
"next": "^14.2.12",
|
||||
"next": "^14.2.20",
|
||||
"nextra": "^2.13.2",
|
||||
"nextra-theme-docs": "^2.13.2",
|
||||
"react": "^18.2.0",
|
||||
@@ -931,14 +931,14 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.12.tgz",
|
||||
"integrity": "sha512-3fP29GIetdwVIfIRyLKM7KrvJaqepv+6pVodEbx0P5CaMLYBtx+7eEg8JYO5L9sveJO87z9eCReceZLi0hxO1Q=="
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.20.tgz",
|
||||
"integrity": "sha512-JfDpuOCB0UBKlEgEy/H6qcBSzHimn/YWjUHzKl1jMeUO+QVRdzmTTl8gFJaNO87c8DXmVKhFCtwxQ9acqB3+Pw=="
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.12.tgz",
|
||||
"integrity": "sha512-crHJ9UoinXeFbHYNok6VZqjKnd8rTd7K3Z2zpyzF1ch7vVNKmhjv/V7EHxep3ILoN8JB9AdRn/EtVVyG9AkCXw==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.20.tgz",
|
||||
"integrity": "sha512-WDfq7bmROa5cIlk6ZNonNdVhKmbCv38XteVFYsxea1vDJt3SnYGgxLGMTXQNfs5OkFvAhmfKKrwe7Y0Hs+rWOg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -951,9 +951,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.12.tgz",
|
||||
"integrity": "sha512-JbEaGbWq18BuNBO+lCtKfxl563Uw9oy2TodnN2ioX00u7V1uzrsSUcg3Ep9ce+P0Z9es+JmsvL2/rLphz+Frcw==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.20.tgz",
|
||||
"integrity": "sha512-XIQlC+NAmJPfa2hruLvr1H1QJJeqOTDV+v7tl/jIdoFvqhoihvSNykLU/G6NMgoeo+e/H7p/VeWSOvMUHKtTIg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -966,9 +966,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.12.tgz",
|
||||
"integrity": "sha512-qBy7OiXOqZrdp88QEl2H4fWalMGnSCrr1agT/AVDndlyw2YJQA89f3ttR/AkEIP9EkBXXeGl6cC72/EZT5r6rw==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.20.tgz",
|
||||
"integrity": "sha512-pnzBrHTPXIMm5QX3QC8XeMkpVuoAYOmyfsO4VlPn+0NrHraNuWjdhe+3xLq01xR++iCvX+uoeZmJDKcOxI201Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -981,9 +981,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.12.tgz",
|
||||
"integrity": "sha512-EfD9L7o9biaQxjwP1uWXnk3vYZi64NVcKUN83hpVkKocB7ogJfyH2r7o1pPnMtir6gHZiGCeHKagJ0yrNSLNHw==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.20.tgz",
|
||||
"integrity": "sha512-WhJJAFpi6yqmUx1momewSdcm/iRXFQS0HU2qlUGlGE/+98eu7JWLD5AAaP/tkK1mudS/rH2f9E3WCEF2iYDydQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -996,9 +996,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.12.tgz",
|
||||
"integrity": "sha512-iQ+n2pxklJew9IpE47hE/VgjmljlHqtcD5UhZVeHICTPbLyrgPehaKf2wLRNjYH75udroBNCgrSSVSVpAbNoYw==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.20.tgz",
|
||||
"integrity": "sha512-ao5HCbw9+iG1Kxm8XsGa3X174Ahn17mSYBQlY6VGsdsYDAbz/ZP13wSLfvlYoIDn1Ger6uYA+yt/3Y9KTIupRg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1011,9 +1011,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.12.tgz",
|
||||
"integrity": "sha512-rFkUkNwcQ0ODn7cxvcVdpHlcOpYxMeyMfkJuzaT74xjAa5v4fxP4xDk5OoYmPi8QNLDs3UgZPMSBmpBuv9zKWA==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.20.tgz",
|
||||
"integrity": "sha512-CXm/kpnltKTT7945np6Td3w7shj/92TMRPyI/VvveFe8+YE+/YOJ5hyAWK5rpx711XO1jBCgXl211TWaxOtkaA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1026,9 +1026,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.12.tgz",
|
||||
"integrity": "sha512-PQFYUvwtHs/u0K85SG4sAdDXYIPXpETf9mcEjWc0R4JmjgMKSDwIU/qfZdavtP6MPNiMjuKGXHCtyhR/M5zo8g==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.20.tgz",
|
||||
"integrity": "sha512-upJn2HGQgKNDbXVfIgmqT2BN8f3z/mX8ddoyi1I565FHbfowVK5pnMEwauvLvaJf4iijvuKq3kw/b6E9oIVRWA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1041,9 +1041,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-ia32-msvc": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.12.tgz",
|
||||
"integrity": "sha512-FAj2hMlcbeCV546eU2tEv41dcJb4NeqFlSXU/xL/0ehXywHnNpaYajOUvn3P8wru5WyQe6cTZ8fvckj/2XN4Vw==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.20.tgz",
|
||||
"integrity": "sha512-igQW/JWciTGJwj3G1ipalD2V20Xfx3ywQy17IV0ciOUBbFhNfyU1DILWsTi32c8KmqgIDviUEulW/yPb2FF90w==",
|
||||
"cpu": [
|
||||
"ia32"
|
||||
],
|
||||
@@ -1056,9 +1056,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.12.tgz",
|
||||
"integrity": "sha512-yu8QvV53sBzoIVRHsxCHqeuS8jYq6Lrmdh0briivuh+Brsp6xjg80MAozUsBTAV9KNmY08KlX0KYTWz1lbPzEg==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.20.tgz",
|
||||
"integrity": "sha512-AFmqeLW6LtxeFTuoB+MXFeM5fm5052i3MU6xD0WzJDOwku6SkZaxb1bxjBaRC8uNqTRTSPl0yMFtjNowIVI67w==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -6759,11 +6759,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "14.2.12",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-14.2.12.tgz",
|
||||
"integrity": "sha512-cDOtUSIeoOvt1skKNihdExWMTybx3exnvbFbb9ecZDIxlvIbREQzt9A5Km3Zn3PfU+IFjyYGsHS+lN9VInAGKA==",
|
||||
"version": "14.2.20",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-14.2.20.tgz",
|
||||
"integrity": "sha512-yPvIiWsiyVYqJlSQxwmzMIReXn5HxFNq4+tlVQ812N1FbvhmE+fDpIAD7bcS2mGYQwPJ5vAsQouyme2eKsxaug==",
|
||||
"dependencies": {
|
||||
"@next/env": "14.2.12",
|
||||
"@next/env": "14.2.20",
|
||||
"@swc/helpers": "0.5.5",
|
||||
"busboy": "1.6.0",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
@@ -6778,15 +6778,15 @@
|
||||
"node": ">=18.17.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "14.2.12",
|
||||
"@next/swc-darwin-x64": "14.2.12",
|
||||
"@next/swc-linux-arm64-gnu": "14.2.12",
|
||||
"@next/swc-linux-arm64-musl": "14.2.12",
|
||||
"@next/swc-linux-x64-gnu": "14.2.12",
|
||||
"@next/swc-linux-x64-musl": "14.2.12",
|
||||
"@next/swc-win32-arm64-msvc": "14.2.12",
|
||||
"@next/swc-win32-ia32-msvc": "14.2.12",
|
||||
"@next/swc-win32-x64-msvc": "14.2.12"
|
||||
"@next/swc-darwin-arm64": "14.2.20",
|
||||
"@next/swc-darwin-x64": "14.2.20",
|
||||
"@next/swc-linux-arm64-gnu": "14.2.20",
|
||||
"@next/swc-linux-arm64-musl": "14.2.20",
|
||||
"@next/swc-linux-x64-gnu": "14.2.20",
|
||||
"@next/swc-linux-x64-musl": "14.2.20",
|
||||
"@next/swc-win32-arm64-msvc": "14.2.20",
|
||||
"@next/swc-win32-ia32-msvc": "14.2.20",
|
||||
"@next/swc-win32-x64-msvc": "14.2.20"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@opentelemetry/api": "^1.1.0",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"dependencies": {
|
||||
"@vercel/analytics": "^1.1.1",
|
||||
"docsgpt-react": "^0.4.8",
|
||||
"next": "^14.2.12",
|
||||
"next": "^14.2.20",
|
||||
"nextra": "^2.13.2",
|
||||
"nextra-theme-docs": "^2.13.2",
|
||||
"react": "^18.2.0",
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
import React from 'react'
|
||||
import styled, { ThemeProvider } from 'styled-components';
|
||||
import React from 'react';
|
||||
import styled, { ThemeProvider, createGlobalStyle } from 'styled-components';
|
||||
import { WidgetCore } from './DocsGPTWidget';
|
||||
import { SearchBarProps } from '@/types';
|
||||
import { getSearchResults } from '../requests/searchAPI'
|
||||
import { getSearchResults } from '../requests/searchAPI';
|
||||
import { Result } from '@/types';
|
||||
import MarkdownIt from 'markdown-it';
|
||||
import { getOS, preprocessSearchResultsToHTML } from '../utils/helper'
|
||||
import { getOS, processMarkdownString } from '../utils/helper';
|
||||
import DOMPurify from 'dompurify';
|
||||
import {
|
||||
CodeIcon,
|
||||
TextAlignLeftIcon,
|
||||
HeadingIcon,
|
||||
ReaderIcon,
|
||||
ListBulletIcon,
|
||||
QuoteIcon
|
||||
} from '@radix-ui/react-icons';
|
||||
const themes = {
|
||||
dark: {
|
||||
bg: '#000',
|
||||
text: '#fff',
|
||||
text: '#EDEDED',
|
||||
primary: {
|
||||
text: "#FAFAFA",
|
||||
bg: '#111111'
|
||||
@@ -21,7 +30,7 @@ const themes = {
|
||||
},
|
||||
light: {
|
||||
bg: '#fff',
|
||||
text: '#000',
|
||||
text: '#171717',
|
||||
primary: {
|
||||
text: "#222327",
|
||||
bg: "#fff"
|
||||
@@ -33,16 +42,30 @@ const themes = {
|
||||
}
|
||||
}
|
||||
|
||||
const GlobalStyle = createGlobalStyle`
|
||||
.highlight {
|
||||
color:#007EE6;
|
||||
}
|
||||
`;
|
||||
|
||||
const loadGeistFont = () => {
|
||||
const link = document.createElement('link');
|
||||
link.href = 'https://fonts.googleapis.com/css2?family=Geist:wght@100..900&display=swap';
|
||||
link.rel = 'stylesheet';
|
||||
document.head.appendChild(link);
|
||||
};
|
||||
|
||||
const Main = styled.div`
|
||||
all:initial;
|
||||
font-family: sans-serif;
|
||||
all: initial;
|
||||
font-family: 'Geist', sans-serif;
|
||||
`
|
||||
const TextField = styled.input<{ inputWidth: string }>`
|
||||
padding: 6px 6px;
|
||||
const SearchButton = styled.button<{ inputWidth: string }>`
|
||||
padding: 6px 6px;
|
||||
font-family: inherit;
|
||||
width: ${({ inputWidth }) => inputWidth};
|
||||
border-radius: 8px;
|
||||
display: inline;
|
||||
color: ${props => props.theme.primary.text};
|
||||
color: ${props => props.theme.secondary.text};
|
||||
outline: none;
|
||||
border: none;
|
||||
background-color: ${props => props.theme.secondary.bg};
|
||||
@@ -50,14 +73,8 @@ const TextField = styled.input<{ inputWidth: string }>`
|
||||
-moz-appearance: none;
|
||||
appearance: none;
|
||||
transition: background-color 128ms linear;
|
||||
&:focus {
|
||||
outline: none;
|
||||
box-shadow:
|
||||
0px 0px 0px 2px rgba(0, 109, 199),
|
||||
0px 0px 6px rgb(0, 90, 163),
|
||||
0px 2px 6px rgba(0, 0, 0, 0.1) ;
|
||||
background-color: ${props => props.theme.primary.bg};
|
||||
}
|
||||
text-align: left;
|
||||
cursor: pointer;
|
||||
`
|
||||
|
||||
const Container = styled.div`
|
||||
@@ -65,61 +82,122 @@ const Container = styled.div`
|
||||
display: inline-block;
|
||||
`
|
||||
const SearchResults = styled.div`
|
||||
position: absolute;
|
||||
display: block;
|
||||
position: fixed;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
background-color: ${props => props.theme.primary.bg};
|
||||
border: 1px solid rgba(0, 0, 0, .1);
|
||||
border-radius: 12px;
|
||||
padding: 8px;
|
||||
width: 576px;
|
||||
min-width: 96%;
|
||||
border: 1px solid ${props => props.theme.secondary.bg};
|
||||
border-radius: 15px;
|
||||
padding: 8px 0px 8px 0px;
|
||||
width: 792px;
|
||||
max-width: 90vw;
|
||||
height: 396px;
|
||||
z-index: 100;
|
||||
height: 25vh;
|
||||
overflow-y: auto;
|
||||
top: 32px;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
color: ${props => props.theme.primary.text};
|
||||
scrollbar-color: lab(48.438 0 0 / 0.4) rgba(0, 0, 0, 0);
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
backdrop-filter: blur(16px);
|
||||
box-sizing: border-box;
|
||||
|
||||
@media only screen and (max-width: 768px) {
|
||||
height: 80vh;
|
||||
width: 90vw;
|
||||
}
|
||||
`;
|
||||
|
||||
const SearchResultsScroll = styled.div`
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
scrollbar-gutter: stable;
|
||||
scrollbar-width: thin;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05), 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
backdrop-filter: blur(16px);
|
||||
@media only screen and (max-width: 768px) {
|
||||
max-height: 100vh;
|
||||
max-width: 80vw;
|
||||
overflow: auto;
|
||||
scrollbar-color: #383838 transparent;
|
||||
padding: 0 16px;
|
||||
`;
|
||||
|
||||
const IconTitleWrapper = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
|
||||
.element-icon{
|
||||
margin: 4px;
|
||||
}
|
||||
`
|
||||
`;
|
||||
|
||||
const Title = styled.h3`
|
||||
font-size: 14px;
|
||||
font-size: 15px;
|
||||
font-weight: 400;
|
||||
color: ${props => props.theme.primary.text};
|
||||
opacity: 0.8;
|
||||
padding-bottom: 6px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
border-bottom: 1px solid ${(props) => props.theme.secondary.text};
|
||||
`
|
||||
margin: 0;
|
||||
overflow-wrap: break-word;
|
||||
white-space: normal;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
`;
|
||||
const ContentWrapper = styled.div`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
`;
|
||||
const Content = styled.div`
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
|
||||
display: flex;
|
||||
margin-left: 8px;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
padding: 4px 0px 0px 12px;
|
||||
font-size: 15px;
|
||||
color: ${props => props.theme.primary.text};
|
||||
line-height: 1.6;
|
||||
border-left: 2px solid #585858;
|
||||
overflow: hidden;
|
||||
`
|
||||
const ContentSegment = styled.div`
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 8px;
|
||||
padding-right: 16px;
|
||||
overflow-wrap: break-word;
|
||||
white-space: normal;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
`
|
||||
|
||||
const ResultWrapper = styled.div`
|
||||
padding: 4px 8px 4px 8px;
|
||||
border-radius: 8px;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
padding: 12px 16px 0 16px;
|
||||
cursor: pointer;
|
||||
&.contains-source:hover{
|
||||
background-color: ${props => props.theme.primary.bg};
|
||||
font-family: 'Geist',sans-serif;
|
||||
transition: background-color 0.2s;
|
||||
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
white-space: normal;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
|
||||
&.contains-source:hover {
|
||||
background-color: rgba(0, 92, 197, 0.15);
|
||||
${Title} {
|
||||
color: rgb(0, 126, 230);
|
||||
}
|
||||
color: rgb(0, 126, 230);
|
||||
}
|
||||
}
|
||||
`
|
||||
const Markdown = styled.div`
|
||||
line-height:20px;
|
||||
font-size: 12px;
|
||||
line-height:18px;
|
||||
font-size: 11px;
|
||||
white-space: pre-wrap;
|
||||
pre {
|
||||
padding: 8px;
|
||||
width: 90%;
|
||||
font-size: 12px;
|
||||
font-size: 11px;
|
||||
border-radius: 6px;
|
||||
overflow-x: auto;
|
||||
background-color: #1B1C1F;
|
||||
@@ -127,7 +205,7 @@ white-space: pre-wrap;
|
||||
}
|
||||
|
||||
h1,h2 {
|
||||
font-size: 16px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: ${(props) => props.theme.text};
|
||||
opacity: 0.8;
|
||||
@@ -135,20 +213,20 @@ white-space: pre-wrap;
|
||||
|
||||
|
||||
h3 {
|
||||
font-size: 14px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
p {
|
||||
margin: 0px;
|
||||
line-height: 1.35rem;
|
||||
font-size: 12px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
code:not(pre code) {
|
||||
border-radius: 6px;
|
||||
padding: 2px 2px;
|
||||
margin: 2px;
|
||||
font-size: 10px;
|
||||
font-size: 9px;
|
||||
display: inline;
|
||||
background-color: #646464;
|
||||
color: #fff ;
|
||||
@@ -197,76 +275,131 @@ const Loader = styled.div`
|
||||
const NoResults = styled.div`
|
||||
margin-top: 2rem;
|
||||
text-align: center;
|
||||
font-size: 1rem;
|
||||
font-size: 14px;
|
||||
color: #888;
|
||||
`;
|
||||
const InfoButton = styled.button`
|
||||
cursor: pointer;
|
||||
padding: 10px 4px 10px 4px;
|
||||
display: block;
|
||||
width: 100%;
|
||||
color: inherit;
|
||||
const AskAIButton = styled.button`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: flex-start;
|
||||
gap: 12px;
|
||||
width: calc(100% - 32px);
|
||||
margin: 0 16px 16px 16px;
|
||||
box-sizing: border-box;
|
||||
height: 50px;
|
||||
padding: 8px 24px;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
background-color: ${(props) => props.theme.bg};
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
margin-bottom: 8px;
|
||||
border:1px solid ${(props) => props.theme.secondary.text};
|
||||
background-color: ${props => props.theme.secondary.bg};
|
||||
color: ${props => props.theme.text};
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s, box-shadow 0.2s;
|
||||
font-size: 16px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
|
||||
&:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
`
|
||||
const SearchHeader = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 12px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid ${props => props.theme.secondary.bg};
|
||||
`
|
||||
|
||||
const TextField = styled.input`
|
||||
width: calc(100% - 32px);
|
||||
margin: 0 16px;
|
||||
padding: 12px 16px;
|
||||
border: none;
|
||||
background-color: transparent;
|
||||
color: ${props => props.theme.text};
|
||||
font-size: 20px;
|
||||
font-weight: 400;
|
||||
outline: none;
|
||||
|
||||
&:focus {
|
||||
border-color: none;
|
||||
}
|
||||
`
|
||||
|
||||
const EscapeInstruction = styled.kbd`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 12px 16px 0;
|
||||
padding: 4px 8px;
|
||||
border-radius: 4px;
|
||||
background-color: transparent;
|
||||
border: 1px solid ${props => props.theme.secondary.text};
|
||||
color: ${props => props.theme.text};
|
||||
font-size: 12px;
|
||||
font-family: 'Geist', sans-serif;
|
||||
white-space: nowrap;
|
||||
cursor: pointer;
|
||||
width: fit-content;
|
||||
&:hover {
|
||||
background-color: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
`
|
||||
export const SearchBar = ({
|
||||
apiKey = "74039c6d-bff7-44ce-ae55-2973cbf13837",
|
||||
apiHost = "https://gptcloud.arc53.com",
|
||||
theme = "dark",
|
||||
placeholder = "Search or Ask AI...",
|
||||
width = "256px"
|
||||
width = "256px",
|
||||
buttonText = "Search here"
|
||||
}: SearchBarProps) => {
|
||||
const [input, setInput] = React.useState<string>("");
|
||||
const [loading, setLoading] = React.useState<boolean>(false);
|
||||
const [isWidgetOpen, setIsWidgetOpen] = React.useState<boolean>(false);
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
const containerRef = React.useRef<HTMLInputElement>(null);
|
||||
const [isResultVisible, setIsResultVisible] = React.useState<boolean>(true);
|
||||
const [isResultVisible, setIsResultVisible] = React.useState<boolean>(false);
|
||||
const [results, setResults] = React.useState<Result[]>([]);
|
||||
const debounceTimeout = React.useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const abortControllerRef = React.useRef<AbortController | null>(null)
|
||||
const abortControllerRef = React.useRef<AbortController | null>(null);
|
||||
const browserOS = getOS();
|
||||
function isTouchDevice() {
|
||||
return 'ontouchstart' in window;
|
||||
}
|
||||
const isTouch = isTouchDevice();
|
||||
const isTouch = 'ontouchstart' in window;
|
||||
|
||||
const getKeyboardInstruction = () => {
|
||||
if (isResultVisible) return "Enter"
|
||||
if (browserOS === 'mac')
|
||||
return "⌘ K"
|
||||
else
|
||||
return "Ctrl K"
|
||||
}
|
||||
if (isResultVisible) return "Enter";
|
||||
return browserOS === 'mac' ? '⌘ + K' : 'Ctrl + K';
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
const handleFocusSearch = (event: KeyboardEvent) => {
|
||||
loadGeistFont()
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (containerRef.current && !containerRef.current.contains(event.target as Node)) {
|
||||
setIsResultVisible(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (
|
||||
((browserOS === 'win' || browserOS === 'linux') && event.ctrlKey && event.key === 'k') ||
|
||||
(browserOS === 'mac' && event.metaKey && event.key === 'k')
|
||||
) {
|
||||
event.preventDefault();
|
||||
inputRef.current?.focus();
|
||||
}
|
||||
}
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
containerRef.current &&
|
||||
!containerRef.current.contains(event.target as Node)
|
||||
) {
|
||||
setIsResultVisible(true);
|
||||
} else if (event.key === 'Escape') {
|
||||
setIsResultVisible(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
document.addEventListener('mousedown', handleClickOutside);
|
||||
document.addEventListener('keydown', handleFocusSearch);
|
||||
document.addEventListener('keydown', handleKeyDown);
|
||||
return () => {
|
||||
setIsResultVisible(true);
|
||||
document.removeEventListener('mousedown', handleClickOutside);
|
||||
document.removeEventListener('keydown', handleKeyDown);
|
||||
};
|
||||
}, [])
|
||||
}, []);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (!input) {
|
||||
setResults([]);
|
||||
@@ -291,8 +424,6 @@ export const SearchBar = ({
|
||||
}, 500);
|
||||
|
||||
return () => {
|
||||
console.log(results);
|
||||
|
||||
abortController.abort();
|
||||
clearTimeout(debounceTimeout.current ?? undefined);
|
||||
};
|
||||
@@ -304,73 +435,106 @@ export const SearchBar = ({
|
||||
openWidget();
|
||||
}
|
||||
};
|
||||
|
||||
const openWidget = () => {
|
||||
setIsWidgetOpen(true);
|
||||
setIsResultVisible(false)
|
||||
}
|
||||
setIsResultVisible(false);
|
||||
};
|
||||
|
||||
const handleClose = () => {
|
||||
setIsWidgetOpen(false);
|
||||
}
|
||||
const md = new MarkdownIt();
|
||||
setIsResultVisible(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<ThemeProvider theme={{ ...themes[theme] }}>
|
||||
<Main>
|
||||
<GlobalStyle />
|
||||
<Container ref={containerRef}>
|
||||
<TextField
|
||||
spellCheck={false}
|
||||
<SearchButton
|
||||
onClick={() => setIsResultVisible(true)}
|
||||
inputWidth={width}
|
||||
onFocus={() => setIsResultVisible(true)}
|
||||
ref={inputRef}
|
||||
onSubmit={() => setIsWidgetOpen(true)}
|
||||
onKeyDown={(e) => handleKeyDown(e)}
|
||||
placeholder={placeholder}
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
/>
|
||||
>
|
||||
{buttonText}
|
||||
</SearchButton>
|
||||
{
|
||||
input.length > 0 && isResultVisible && (
|
||||
isResultVisible && (
|
||||
<SearchResults>
|
||||
<InfoButton onClick={openWidget}>
|
||||
{
|
||||
isTouch ?
|
||||
"Ask the AI" :
|
||||
<>
|
||||
Press <span style={{ fontSize: "16px" }}>↵</span> Enter to ask the AI
|
||||
</>
|
||||
}
|
||||
</InfoButton>
|
||||
{!loading ?
|
||||
(results.length > 0 ?
|
||||
results.map((res, key) => {
|
||||
const containsSource = res.source !== 'local';
|
||||
const filteredResults = preprocessSearchResultsToHTML(res.text,input)
|
||||
if (filteredResults)
|
||||
return (
|
||||
<ResultWrapper
|
||||
key={key}
|
||||
onClick={() => {
|
||||
if (!containsSource) return;
|
||||
window.open(res.source, '_blank', 'noopener, noreferrer')
|
||||
}}
|
||||
className={containsSource ? "contains-source" : ""}>
|
||||
<Title>{res.title}</Title>
|
||||
<Content>
|
||||
<Markdown
|
||||
dangerouslySetInnerHTML={{ __html: filteredResults }}
|
||||
/>
|
||||
</Content>
|
||||
</ResultWrapper>
|
||||
)
|
||||
else {
|
||||
setResults((prevItems) => prevItems.filter((_, index) => index !== key));
|
||||
}
|
||||
})
|
||||
:
|
||||
<NoResults>No results</NoResults>
|
||||
)
|
||||
:
|
||||
<Loader />
|
||||
}
|
||||
<SearchHeader>
|
||||
<TextField
|
||||
ref={inputRef}
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
onKeyDown={(e) => handleKeyDown(e)}
|
||||
placeholder={placeholder}
|
||||
autoFocus
|
||||
/>
|
||||
<EscapeInstruction onClick={() => setIsResultVisible(false)}>
|
||||
Esc
|
||||
</EscapeInstruction>
|
||||
</SearchHeader>
|
||||
<AskAIButton onClick={openWidget}>
|
||||
<img
|
||||
src="https://d3dg1063dc54p9.cloudfront.net/cute-docsgpt.png"
|
||||
alt="DocsGPT"
|
||||
width={24}
|
||||
height={24}
|
||||
/>
|
||||
<span>Ask the AI</span>
|
||||
</AskAIButton>
|
||||
<SearchResultsScroll>
|
||||
{!loading ? (
|
||||
results.length > 0 ? (
|
||||
results.map((res, key) => {
|
||||
const containsSource = res.source !== 'local';
|
||||
const processedResults = processMarkdownString(res.text, input);
|
||||
if (processedResults)
|
||||
return (
|
||||
<ResultWrapper
|
||||
key={key}
|
||||
onClick={() => {
|
||||
if (!containsSource) return;
|
||||
window.open(res.source, '_blank', 'noopener, noreferrer');
|
||||
}}
|
||||
>
|
||||
<div style={{ flex: 1 }}>
|
||||
<ContentWrapper>
|
||||
<IconTitleWrapper>
|
||||
<ReaderIcon className="title-icon" />
|
||||
<Title>{res.title}</Title>
|
||||
</IconTitleWrapper>
|
||||
<Content>
|
||||
{processedResults.map((element, index) => (
|
||||
<ContentSegment key={index}>
|
||||
<IconTitleWrapper>
|
||||
{element.tag === 'code' && <CodeIcon className="element-icon" />}
|
||||
{(element.tag === 'bulletList' || element.tag === 'numberedList') && <ListBulletIcon className="element-icon" />}
|
||||
{element.tag === 'text' && <TextAlignLeftIcon className="element-icon" />}
|
||||
{element.tag === 'heading' && <HeadingIcon className="element-icon" />}
|
||||
{element.tag === 'blockquote' && <QuoteIcon className="element-icon" />}
|
||||
</IconTitleWrapper>
|
||||
<div
|
||||
style={{ flex: 1 }}
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: DOMPurify.sanitize(element.content),
|
||||
}}
|
||||
/>
|
||||
</ContentSegment>
|
||||
))}
|
||||
</Content>
|
||||
</ContentWrapper>
|
||||
</div>
|
||||
</ResultWrapper>
|
||||
);
|
||||
return null;
|
||||
})
|
||||
) : (
|
||||
<NoResults>No results found</NoResults>
|
||||
)
|
||||
) : (
|
||||
<Loader />
|
||||
)}
|
||||
</SearchResultsScroll>
|
||||
</SearchResults>
|
||||
)
|
||||
}
|
||||
@@ -402,4 +566,4 @@ export const SearchBar = ({
|
||||
</Main>
|
||||
</ThemeProvider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,9 +44,10 @@ export interface WidgetCoreProps extends WidgetProps {
|
||||
export interface SearchBarProps {
|
||||
apiHost?: string;
|
||||
apiKey?: string;
|
||||
theme?:THEME;
|
||||
placeholder?:string;
|
||||
width?:string;
|
||||
theme?: THEME;
|
||||
placeholder?: string;
|
||||
width?: string;
|
||||
buttonText?: string;
|
||||
}
|
||||
|
||||
export interface Result {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import MarkdownIt from "markdown-it";
|
||||
import DOMPurify from "dompurify";
|
||||
export const getOS = () => {
|
||||
const platform = window.navigator.platform;
|
||||
const userAgent = window.navigator.userAgent || window.navigator.vendor;
|
||||
@@ -27,61 +25,127 @@ export const getOS = () => {
|
||||
return 'other';
|
||||
};
|
||||
|
||||
export const preprocessSearchResultsToHTML = (text: string, keyword: string) => {
|
||||
const md = new MarkdownIt();
|
||||
const htmlString = md.render(text);
|
||||
interface ParsedElement {
|
||||
content: string;
|
||||
tag: string;
|
||||
}
|
||||
|
||||
// Container for processed HTML
|
||||
const filteredResults = document.createElement("div");
|
||||
filteredResults.innerHTML = htmlString;
|
||||
export const processMarkdownString = (markdown: string, keyword?: string): ParsedElement[] => {
|
||||
const lines = markdown.trim().split('\n');
|
||||
const keywordLower = keyword?.toLowerCase();
|
||||
|
||||
if (!processNode(filteredResults, keyword.trim())) return null;
|
||||
const escapeRegExp = (str: string) => str.replace(/[-\/\\^$*+?.()|[\]{}]/g, '\\$&');
|
||||
const escapedKeyword = keyword ? escapeRegExp(keyword) : '';
|
||||
const keywordRegex = keyword ? new RegExp(`(${escapedKeyword})`, 'gi') : null;
|
||||
|
||||
return filteredResults.innerHTML.trim() ? filteredResults.outerHTML : null;
|
||||
};
|
||||
let isInCodeBlock = false;
|
||||
let codeBlockContent: string[] = [];
|
||||
let matchingLines: ParsedElement[] = [];
|
||||
let firstLine: ParsedElement | null = null;
|
||||
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const trimmedLine = lines[i].trim();
|
||||
if (!trimmedLine) continue;
|
||||
|
||||
if (trimmedLine.startsWith('```')) {
|
||||
if (!isInCodeBlock) {
|
||||
isInCodeBlock = true;
|
||||
codeBlockContent = [];
|
||||
} else {
|
||||
isInCodeBlock = false;
|
||||
const codeContent = codeBlockContent.join('\n');
|
||||
const parsedElement: ParsedElement = {
|
||||
content: codeContent,
|
||||
tag: 'code'
|
||||
};
|
||||
|
||||
// Recursive function to process nodes
|
||||
const processNode = (node: Node, keyword: string): boolean => {
|
||||
if (!firstLine) {
|
||||
firstLine = parsedElement;
|
||||
}
|
||||
|
||||
const keywordRegex = new RegExp(`(${keyword})`, "gi");
|
||||
if (node.nodeType === Node.TEXT_NODE) {
|
||||
const textContent = node.textContent || "";
|
||||
|
||||
if (textContent.toLowerCase().includes(keyword.toLowerCase())) {
|
||||
const highlightedHTML = textContent.replace(
|
||||
keywordRegex,
|
||||
`<mark>$1</mark>`
|
||||
);
|
||||
const tempContainer = document.createElement("div");
|
||||
tempContainer.innerHTML = highlightedHTML;
|
||||
|
||||
// Replace the text node with highlighted content
|
||||
while (tempContainer.firstChild) {
|
||||
node.parentNode?.insertBefore(tempContainer.firstChild, node);
|
||||
if (keywordLower && codeContent.toLowerCase().includes(keywordLower)) {
|
||||
parsedElement.content = parsedElement.content.replace(keywordRegex!, '<span class="highlight">$1</span>');
|
||||
matchingLines.push(parsedElement);
|
||||
}
|
||||
}
|
||||
node.parentNode?.removeChild(node);
|
||||
|
||||
return true;
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
} else if (node.nodeType === Node.ELEMENT_NODE) {
|
||||
if (isInCodeBlock) {
|
||||
codeBlockContent.push(trimmedLine);
|
||||
continue;
|
||||
}
|
||||
|
||||
const children = Array.from(node.childNodes);
|
||||
let hasKeyword = false;
|
||||
let parsedElement: ParsedElement | null = null;
|
||||
|
||||
children.forEach((child) => {
|
||||
if (!processNode(child, keyword)) {
|
||||
node.removeChild(child);
|
||||
} else {
|
||||
hasKeyword = true;
|
||||
}
|
||||
});
|
||||
const headingMatch = trimmedLine.match(/^(#{1,6})\s+(.+)$/);
|
||||
const bulletMatch = trimmedLine.match(/^[-*]\s+(.+)$/);
|
||||
const numberedMatch = trimmedLine.match(/^\d+\.\s+(.+)$/);
|
||||
const blockquoteMatch = trimmedLine.match(/^>+\s*(.+)$/);
|
||||
|
||||
return hasKeyword;
|
||||
let content = trimmedLine;
|
||||
|
||||
if (headingMatch) {
|
||||
content = headingMatch[2];
|
||||
parsedElement = {
|
||||
content: content,
|
||||
tag: 'heading'
|
||||
};
|
||||
} else if (bulletMatch) {
|
||||
content = bulletMatch[1];
|
||||
parsedElement = {
|
||||
content: content,
|
||||
tag: 'bulletList'
|
||||
};
|
||||
} else if (numberedMatch) {
|
||||
content = numberedMatch[1];
|
||||
parsedElement = {
|
||||
content: content,
|
||||
tag: 'numberedList'
|
||||
};
|
||||
} else if (blockquoteMatch) {
|
||||
content = blockquoteMatch[1];
|
||||
parsedElement = {
|
||||
content: content,
|
||||
tag: 'blockquote'
|
||||
};
|
||||
} else {
|
||||
parsedElement = {
|
||||
content: content,
|
||||
tag: 'text'
|
||||
};
|
||||
}
|
||||
|
||||
if (!firstLine) {
|
||||
firstLine = parsedElement;
|
||||
}
|
||||
|
||||
if (keywordLower && parsedElement.content.toLowerCase().includes(keywordLower)) {
|
||||
parsedElement.content = parsedElement.content.replace(keywordRegex!, '<span class="highlight">$1</span>');
|
||||
matchingLines.push(parsedElement);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
if (isInCodeBlock && codeBlockContent.length > 0) {
|
||||
const codeContent = codeBlockContent.join('\n');
|
||||
const parsedElement: ParsedElement = {
|
||||
content: codeContent,
|
||||
tag: 'code'
|
||||
};
|
||||
|
||||
if (!firstLine) {
|
||||
firstLine = parsedElement;
|
||||
}
|
||||
|
||||
if (keywordLower && codeContent.toLowerCase().includes(keywordLower)) {
|
||||
parsedElement.content = parsedElement.content.replace(keywordRegex!, '<span class="highlight">$1</span>');
|
||||
matchingLines.push(parsedElement);
|
||||
}
|
||||
}
|
||||
|
||||
if (keywordLower && matchingLines.length > 0) {
|
||||
return matchingLines;
|
||||
}
|
||||
|
||||
return firstLine ? [firstLine] : [];
|
||||
};
|
||||
|
||||
1027
frontend/package-lock.json
generated
1027
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -21,8 +21,8 @@
|
||||
"dependencies": {
|
||||
"@reduxjs/toolkit": "^2.2.7",
|
||||
"chart.js": "^4.4.4",
|
||||
"i18next": "^23.15.1",
|
||||
"i18next-browser-languagedetector": "^8.0.0",
|
||||
"i18next": "^24.2.0",
|
||||
"i18next-browser-languagedetector": "^8.0.2",
|
||||
"prop-types": "^15.8.1",
|
||||
"react": "^18.2.0",
|
||||
"react-chartjs-2": "^5.2.0",
|
||||
@@ -30,10 +30,10 @@
|
||||
"react-dom": "^18.3.1",
|
||||
"react-helmet": "^6.1.0",
|
||||
"react-dropzone": "^14.3.5",
|
||||
"react-i18next": "^15.0.2",
|
||||
"react-i18next": "^15.4.0",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-redux": "^8.0.5",
|
||||
"react-router-dom": "^6.8.1",
|
||||
"react-router-dom": "^7.1.1",
|
||||
"react-syntax-highlighter": "^15.5.0",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
@@ -55,15 +55,15 @@
|
||||
"eslint-plugin-n": "^15.7.0",
|
||||
"eslint-plugin-prettier": "^5.2.1",
|
||||
"eslint-plugin-promise": "^6.6.0",
|
||||
"eslint-plugin-react": "^7.37.2",
|
||||
"eslint-plugin-react": "^7.37.3",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"husky": "^8.0.0",
|
||||
"lint-staged": "^15.2.10",
|
||||
"postcss": "^8.4.41",
|
||||
"lint-staged": "^15.3.0",
|
||||
"postcss": "^8.4.49",
|
||||
"prettier": "^3.3.3",
|
||||
"prettier-plugin-tailwindcss": "^0.6.8",
|
||||
"tailwindcss": "^3.4.15",
|
||||
"typescript": "^5.6.2",
|
||||
"typescript": "^5.7.2",
|
||||
"vite": "^5.4.11",
|
||||
"vite-plugin-svgr": "^4.2.0"
|
||||
}
|
||||
|
||||
1
frontend/public/toolIcons/tool_cryptoprice.svg
Normal file
1
frontend/public/toolIcons/tool_cryptoprice.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 122.88 122.88"><path d="M17.89 0h88.9c8.85 0 16.1 7.24 16.1 16.1v90.68c0 8.85-7.24 16.1-16.1 16.1H16.1c-8.85 0-16.1-7.24-16.1-16.1v-88.9C0 8.05 8.05 0 17.89 0zm57.04 66.96l16.46 4.96c-1.1 4.61-2.84 8.47-5.23 11.56-2.38 3.1-5.32 5.43-8.85 7-3.52 1.57-8.01 2.36-13.45 2.36-6.62 0-12.01-.96-16.21-2.87-4.19-1.92-7.79-5.3-10.83-10.13-3.04-4.82-4.57-11.02-4.57-18.54 0-10.04 2.67-17.76 8.02-23.17 5.36-5.39 12.93-8.09 22.71-8.09 7.65 0 13.68 1.54 18.06 4.64 4.37 3.1 7.64 7.85 9.76 14.27l-16.55 3.66c-.58-1.84-1.19-3.18-1.82-4.03-1.06-1.43-2.35-2.53-3.86-3.3-1.53-.78-3.22-1.16-5.11-1.16-4.27 0-7.54 1.71-9.8 5.12-1.71 2.53-2.57 6.52-2.57 11.94 0 6.73 1.02 11.33 3.07 13.83 2.05 2.49 4.92 3.73 8.63 3.73 3.59 0 6.31-1 8.15-3.03 1.83-1.99 3.16-4.92 3.99-8.75z" fill-rule="evenodd" clip-rule="evenodd"/></svg>
|
||||
|
After Width: | Height: | Size: 855 B |
10
frontend/public/toolIcons/tool_telegram.svg
Normal file
10
frontend/public/toolIcons/tool_telegram.svg
Normal file
@@ -0,0 +1,10 @@
|
||||
<svg width="24" height="25" viewBox="0 0 24 25" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 0.5C8.81812 0.5 5.76375 1.76506 3.51562 4.01469C1.2652 6.26522 0.000643966 9.31734 0 12.5C0 15.6813 1.26562 18.7357 3.51562 20.9853C5.76375 23.2349 8.81812 24.5 12 24.5C15.1819 24.5 18.2362 23.2349 20.4844 20.9853C22.7344 18.7357 24 15.6813 24 12.5C24 9.31869 22.7344 6.26431 20.4844 4.01469C18.2362 1.76506 15.1819 0.5 12 0.5Z" fill="url(#paint0_linear_5586_9958)"/>
|
||||
<path d="M5.43282 12.373C8.93157 10.849 11.2641 9.8443 12.4303 9.3588C15.7641 7.97261 16.4559 7.73186 16.9078 7.7237C17.0072 7.72211 17.2284 7.74667 17.3728 7.86339C17.4928 7.96183 17.5266 8.09495 17.5434 8.18842C17.5584 8.2818 17.5791 8.49461 17.5622 8.66074C17.3822 10.5582 16.6003 15.1629 16.2028 17.2882C16.0359 18.1874 15.7041 18.4889 15.3834 18.5184C14.6859 18.5825 14.1572 18.0579 13.4822 17.6155C12.4266 16.9231 11.8303 16.4922 10.8047 15.8167C9.6197 15.0359 10.3884 14.6067 11.0634 13.9055C11.2397 13.7219 14.3109 10.9291 14.3691 10.6758C14.3766 10.6441 14.3841 10.526 14.3128 10.4637C14.2434 10.4013 14.1403 10.4227 14.0653 10.4395C13.9584 10.4635 12.2728 11.5788 9.00282 13.7851C8.52469 14.114 8.09157 14.2743 7.70157 14.2659C7.27407 14.2567 6.44907 14.0236 5.83595 13.8245C5.08595 13.5802 4.48782 13.451 4.54032 13.036C4.56657 12.82 4.8647 12.599 5.43282 12.373Z" fill="white"/>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_5586_9958" x1="1200" y1="0.5" x2="1200" y2="2400.5" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#2AABEE"/>
|
||||
<stop offset="1" stop-color="#229ED9"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
@@ -18,6 +18,11 @@ const endpoints = {
|
||||
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
|
||||
LOGS: `/api/get_user_logs`,
|
||||
MANAGE_SYNC: '/api/manage_sync',
|
||||
GET_AVAILABLE_TOOLS: '/api/available_tools',
|
||||
GET_USER_TOOLS: '/api/get_tools',
|
||||
CREATE_TOOL: '/api/create_tool',
|
||||
UPDATE_TOOL_STATUS: '/api/update_tool_status',
|
||||
UPDATE_TOOL: '/api/update_tool',
|
||||
},
|
||||
CONVERSATION: {
|
||||
ANSWER: '/api/answer',
|
||||
|
||||
@@ -35,6 +35,16 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.LOGS, data),
|
||||
manageSync: (data: any): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MANAGE_SYNC, data),
|
||||
getAvailableTools: (): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS),
|
||||
getUserTools: (): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.GET_USER_TOOLS),
|
||||
createTool: (data: any): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.CREATE_TOOL, data),
|
||||
updateToolStatus: (data: any): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data),
|
||||
updateTool: (data: any): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.UPDATE_TOOL, data),
|
||||
};
|
||||
|
||||
export default userService;
|
||||
|
||||
3
frontend/src/assets/cogwheel.svg
Normal file
3
frontend/src/assets/cogwheel.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="16" height="17" viewBox="0 0 16 17" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8.00182 11.2502C7.23838 11.2502 6.50621 10.9552 5.96637 10.4301C5.42654 9.90499 5.12327 9.1928 5.12327 8.4502C5.12327 7.70759 5.42654 6.9954 5.96637 6.4703C6.50621 5.9452 7.23838 5.6502 8.00182 5.6502C8.76525 5.6502 9.49743 5.9452 10.0373 6.4703C10.5771 6.9954 10.8804 7.70759 10.8804 8.4502C10.8804 9.1928 10.5771 9.90499 10.0373 10.4301C9.49743 10.9552 8.76525 11.2502 8.00182 11.2502ZM14.1126 9.2262C14.1455 8.9702 14.1701 8.7142 14.1701 8.4502C14.1701 8.1862 14.1455 7.9222 14.1126 7.6502L15.8479 6.3462C16.0042 6.2262 16.0453 6.0102 15.9466 5.8342L14.3017 3.0662C14.203 2.8902 13.981 2.8182 13.8 2.8902L11.7522 3.6902C11.3245 3.3782 10.8804 3.1062 10.3622 2.9062L10.0579 0.786197C10.0412 0.69197 9.99076 0.606538 9.91549 0.545038C9.84022 0.483538 9.745 0.449939 9.6467 0.450197H6.35693C6.15132 0.450197 5.97861 0.594197 5.94571 0.786197L5.64141 2.9062C5.12327 3.1062 4.67915 3.3782 4.25148 3.6902L2.2036 2.8902C2.02266 2.8182 1.8006 2.8902 1.70191 3.0662L0.0570212 5.8342C-0.0498963 6.0102 -0.00054964 6.2262 0.155714 6.3462L1.89107 7.6502C1.85817 7.9222 1.8335 8.1862 1.8335 8.4502C1.8335 8.7142 1.85817 8.9702 1.89107 9.2262L0.155714 10.5542C-0.00054964 10.6742 -0.0498963 10.8902 0.0570212 11.0662L1.70191 13.8342C1.8006 14.0102 2.02266 14.0742 2.2036 14.0102L4.25148 13.2022C4.67915 13.5222 5.12327 13.7942 5.64141 13.9942L5.94571 16.1142C5.97861 16.3062 6.15132 16.4502 6.35693 16.4502H9.6467C9.85231 16.4502 10.025 16.3062 10.0579 16.1142L10.3622 13.9942C10.8804 13.7862 11.3245 13.5222 11.7522 13.2022L13.8 14.0102C13.981 14.0742 14.203 14.0102 14.3017 13.8342L15.9466 11.0662C16.0453 10.8902 16.0042 10.6742 15.8479 10.5542L14.1126 9.2262Z" fill="#747474"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.7 KiB |
@@ -13,6 +13,7 @@ const useTabs = () => {
|
||||
t('settings.apiKeys.label'),
|
||||
t('settings.analytics.label'),
|
||||
t('settings.logs.label'),
|
||||
t('settings.tools.label'),
|
||||
];
|
||||
return tabs;
|
||||
};
|
||||
|
||||
@@ -155,8 +155,22 @@ export default function Conversation() {
|
||||
const handleFeedback = (query: Query, feedback: FEEDBACK, index: number) => {
|
||||
const prevFeedback = query.feedback;
|
||||
dispatch(updateQuery({ index, query: { feedback } }));
|
||||
handleSendFeedback(query.prompt, query.response!, feedback).catch(() =>
|
||||
dispatch(updateQuery({ index, query: { feedback: prevFeedback } })),
|
||||
handleSendFeedback(
|
||||
query.prompt,
|
||||
query.response!,
|
||||
feedback,
|
||||
conversationId as string,
|
||||
index,
|
||||
).catch(() =>
|
||||
handleSendFeedback(
|
||||
query.prompt,
|
||||
query.response!,
|
||||
feedback,
|
||||
conversationId as string,
|
||||
index,
|
||||
).catch(() =>
|
||||
dispatch(updateQuery({ index, query: { feedback: prevFeedback } })),
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -436,7 +436,8 @@ const ConversationBubble = forwardRef<
|
||||
feedback === 'LIKE' || type !== 'ERROR'
|
||||
? 'group-hover:lg:visible'
|
||||
: ''
|
||||
}`}
|
||||
}
|
||||
${feedback === 'DISLIKE' && type !== 'ERROR' ? 'hidden' : ''}`}
|
||||
>
|
||||
<div>
|
||||
<div
|
||||
@@ -454,9 +455,15 @@ const ConversationBubble = forwardRef<
|
||||
: 'fill-none stroke-gray-4000'
|
||||
}`}
|
||||
onClick={() => {
|
||||
handleFeedback?.('LIKE');
|
||||
setIsLikeClicked(true);
|
||||
setIsDislikeClicked(false);
|
||||
if (feedback === undefined || feedback === null) {
|
||||
handleFeedback?.('LIKE');
|
||||
setIsLikeClicked(true);
|
||||
setIsDislikeClicked(false);
|
||||
} else if (feedback === 'LIKE') {
|
||||
handleFeedback?.(null);
|
||||
setIsLikeClicked(false);
|
||||
setIsDislikeClicked(false);
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setIsLikeHovered(true)}
|
||||
onMouseLeave={() => setIsLikeHovered(false)}
|
||||
@@ -471,7 +478,7 @@ const ConversationBubble = forwardRef<
|
||||
feedback === 'DISLIKE' || type !== 'ERROR'
|
||||
? 'group-hover:lg:visible'
|
||||
: ''
|
||||
}`}
|
||||
} ${feedback === 'LIKE' && type !== 'ERROR' ? ' hidden' : ''} `}
|
||||
>
|
||||
<div>
|
||||
<div
|
||||
@@ -488,9 +495,15 @@ const ConversationBubble = forwardRef<
|
||||
: 'fill-none stroke-gray-4000'
|
||||
}`}
|
||||
onClick={() => {
|
||||
handleFeedback?.('DISLIKE');
|
||||
setIsDislikeClicked(true);
|
||||
setIsLikeClicked(false);
|
||||
if (feedback === undefined || feedback === null) {
|
||||
handleFeedback?.('DISLIKE');
|
||||
setIsDislikeClicked(true);
|
||||
setIsLikeClicked(false);
|
||||
} else if (feedback === 'DISLIKE') {
|
||||
handleFeedback?.(null);
|
||||
setIsLikeClicked(false);
|
||||
setIsDislikeClicked(false);
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setIsDislikeHovered(true)}
|
||||
onMouseLeave={() => setIsDislikeHovered(false)}
|
||||
|
||||
@@ -202,12 +202,16 @@ export function handleSendFeedback(
|
||||
prompt: string,
|
||||
response: string,
|
||||
feedback: FEEDBACK,
|
||||
conversation_id: string,
|
||||
prompt_index: number,
|
||||
) {
|
||||
return conversationService
|
||||
.feedback({
|
||||
question: prompt,
|
||||
answer: response,
|
||||
feedback: feedback,
|
||||
conversation_id: conversation_id,
|
||||
question_index: prompt_index,
|
||||
})
|
||||
.then((response) => {
|
||||
if (response.ok) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
|
||||
export type Status = 'idle' | 'loading' | 'failed';
|
||||
export type FEEDBACK = 'LIKE' | 'DISLIKE';
|
||||
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
|
||||
|
||||
export interface Message {
|
||||
text: string;
|
||||
|
||||
@@ -72,6 +72,15 @@ body.dark {
|
||||
.table-default td:last-child {
|
||||
@apply border-r-0; /* Ensure no right border on the last column */
|
||||
}
|
||||
|
||||
.table-default th,
|
||||
.table-default td {
|
||||
min-width: 150px;
|
||||
max-width: 320px;
|
||||
overflow: auto;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: grey transparent;
|
||||
}
|
||||
}
|
||||
|
||||
/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */
|
||||
|
||||
@@ -73,6 +73,9 @@
|
||||
},
|
||||
"logs": {
|
||||
"label": "Logs"
|
||||
},
|
||||
"tools": {
|
||||
"label": "Tools"
|
||||
}
|
||||
},
|
||||
"modals": {
|
||||
|
||||
136
frontend/src/modals/AddToolModal.tsx
Normal file
136
frontend/src/modals/AddToolModal.tsx
Normal file
@@ -0,0 +1,136 @@
|
||||
import React from 'react';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Exit from '../assets/exit.svg';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { AvailableTool } from './types';
|
||||
import ConfigToolModal from './ConfigToolModal';
|
||||
|
||||
export default function AddToolModal({
|
||||
message,
|
||||
modalState,
|
||||
setModalState,
|
||||
getUserTools,
|
||||
}: {
|
||||
message: string;
|
||||
modalState: ActiveState;
|
||||
setModalState: (state: ActiveState) => void;
|
||||
getUserTools: () => void;
|
||||
}) {
|
||||
const [availableTools, setAvailableTools] = React.useState<AvailableTool[]>(
|
||||
[],
|
||||
);
|
||||
const [selectedTool, setSelectedTool] = React.useState<AvailableTool | null>(
|
||||
null,
|
||||
);
|
||||
const [configModalState, setConfigModalState] =
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
|
||||
const getAvailableTools = () => {
|
||||
userService
|
||||
.getAvailableTools()
|
||||
.then((res) => {
|
||||
return res.json();
|
||||
})
|
||||
.then((data) => {
|
||||
setAvailableTools(data.data);
|
||||
});
|
||||
};
|
||||
|
||||
const handleAddTool = (tool: AvailableTool) => {
|
||||
if (Object.keys(tool.configRequirements).length === 0) {
|
||||
userService
|
||||
.createTool({
|
||||
name: tool.name,
|
||||
displayName: tool.displayName,
|
||||
description: tool.description,
|
||||
config: {},
|
||||
actions: tool.actions,
|
||||
status: true,
|
||||
})
|
||||
.then((res) => {
|
||||
if (res.status === 200) {
|
||||
getUserTools();
|
||||
setModalState('INACTIVE');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
setModalState('INACTIVE');
|
||||
setConfigModalState('ACTIVE');
|
||||
}
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
if (modalState === 'ACTIVE') getAvailableTools();
|
||||
}, [modalState]);
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
className={`${
|
||||
modalState === 'ACTIVE' ? 'visible' : 'hidden'
|
||||
} fixed top-0 left-0 z-30 h-screen w-screen bg-gray-alpha flex items-center justify-center`}
|
||||
>
|
||||
<article className="flex h-[85vh] w-[90vw] md:w-[75vw] flex-col gap-4 rounded-2xl bg-[#FBFBFB] shadow-lg dark:bg-[#26272E]">
|
||||
<div className="relative">
|
||||
<button
|
||||
className="absolute top-3 right-4 m-2 w-3"
|
||||
onClick={() => {
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
>
|
||||
<img className="filter dark:invert" src={Exit} />
|
||||
</button>
|
||||
<div className="p-6">
|
||||
<h2 className="font-semibold text-xl text-jet dark:text-bright-gray px-3">
|
||||
Select a tool to set up
|
||||
</h2>
|
||||
<div className="mt-5 grid grid-cols-3 gap-4 h-[73vh] overflow-auto px-3 py-px">
|
||||
{availableTools.map((tool, index) => (
|
||||
<div
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
key={index}
|
||||
className="h-52 w-full p-6 border rounded-2xl border-silver dark:border-[#4D4E58] flex flex-col justify-between dark:bg-[#32333B] cursor-pointer"
|
||||
onClick={() => {
|
||||
setSelectedTool(tool);
|
||||
handleAddTool(tool);
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
setSelectedTool(tool);
|
||||
handleAddTool(tool);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="px-1 w-full flex items-center justify-between">
|
||||
<img
|
||||
src={`/toolIcons/tool_${tool.name}.svg`}
|
||||
className="h-8 w-8"
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-[9px]">
|
||||
<p className="px-1 text-sm font-semibold text-eerie-black dark:text-white leading-relaxed capitalize">
|
||||
{tool.displayName}
|
||||
</p>
|
||||
<p className="mt-1 px-1 h-24 overflow-auto text-sm text-gray-600 dark:text-[#8a8a8c] leading-relaxed">
|
||||
{tool.description}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</article>
|
||||
</div>
|
||||
<ConfigToolModal
|
||||
modalState={configModalState}
|
||||
setModalState={setConfigModalState}
|
||||
tool={selectedTool}
|
||||
getUserTools={getUserTools}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
95
frontend/src/modals/ConfigToolModal.tsx
Normal file
95
frontend/src/modals/ConfigToolModal.tsx
Normal file
@@ -0,0 +1,95 @@
|
||||
import React from 'react';
|
||||
|
||||
import Exit from '../assets/exit.svg';
|
||||
import Input from '../components/Input';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { AvailableTool } from './types';
|
||||
import userService from '../api/services/userService';
|
||||
|
||||
export default function ConfigToolModal({
|
||||
modalState,
|
||||
setModalState,
|
||||
tool,
|
||||
getUserTools,
|
||||
}: {
|
||||
modalState: ActiveState;
|
||||
setModalState: (state: ActiveState) => void;
|
||||
tool: AvailableTool | null;
|
||||
getUserTools: () => void;
|
||||
}) {
|
||||
const [authKey, setAuthKey] = React.useState<string>('');
|
||||
|
||||
const handleAddTool = (tool: AvailableTool) => {
|
||||
userService
|
||||
.createTool({
|
||||
name: tool.name,
|
||||
displayName: tool.displayName,
|
||||
description: tool.description,
|
||||
config: { token: authKey },
|
||||
actions: tool.actions,
|
||||
status: true,
|
||||
})
|
||||
.then(() => {
|
||||
setModalState('INACTIVE');
|
||||
getUserTools();
|
||||
});
|
||||
};
|
||||
return (
|
||||
<div
|
||||
className={`${
|
||||
modalState === 'ACTIVE' ? 'visible' : 'hidden'
|
||||
} fixed top-0 left-0 z-30 h-screen w-screen bg-gray-alpha flex items-center justify-center`}
|
||||
>
|
||||
<article className="flex w-11/12 sm:w-[512px] flex-col gap-4 rounded-2xl bg-white shadow-lg dark:bg-[#26272E]">
|
||||
<div className="relative">
|
||||
<button
|
||||
className="absolute top-3 right-4 m-2 w-3"
|
||||
onClick={() => {
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
>
|
||||
<img className="filter dark:invert" src={Exit} />
|
||||
</button>
|
||||
<div className="p-6">
|
||||
<h2 className="font-semibold text-xl text-jet dark:text-bright-gray px-3">
|
||||
Tool Config
|
||||
</h2>
|
||||
<p className="mt-5 text-sm text-gray-600 dark:text-gray-400 px-3">
|
||||
Type: <span className="font-semibold">{tool?.name} </span>
|
||||
</p>
|
||||
<div className="mt-6 relative px-3">
|
||||
<span className="absolute left-5 -top-2 bg-white px-2 text-xs text-gray-4000 dark:bg-[#26272E] dark:text-silver">
|
||||
API Key / Oauth
|
||||
</span>
|
||||
<Input
|
||||
type="text"
|
||||
value={authKey}
|
||||
onChange={(e) => setAuthKey(e.target.value)}
|
||||
borderVariant="thin"
|
||||
placeholder="Enter API Key / Oauth"
|
||||
></Input>
|
||||
</div>
|
||||
<div className="mt-8 flex flex-row-reverse gap-1 px-3">
|
||||
<button
|
||||
onClick={() => {
|
||||
handleAddTool(tool as AvailableTool);
|
||||
}}
|
||||
className="rounded-3xl bg-purple-30 px-5 py-2 text-sm text-white transition-all hover:bg-[#6F3FD1]"
|
||||
>
|
||||
Add Tool
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
className="cursor-pointer rounded-3xl px-5 py-2 text-sm font-medium hover:bg-gray-100 dark:bg-transparent dark:text-light-gray dark:hover:bg-[#767183]/50"
|
||||
>
|
||||
Close
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</article>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,3 +1,15 @@
|
||||
export type AvailableTool = {
|
||||
name: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
configRequirements: object;
|
||||
actions: {
|
||||
name: string;
|
||||
description: string;
|
||||
parameters: object;
|
||||
}[];
|
||||
};
|
||||
|
||||
export type WrapperModalProps = {
|
||||
children?: React.ReactNode;
|
||||
isPerformingTask?: boolean;
|
||||
|
||||
@@ -181,7 +181,7 @@ const Documents: React.FC<DocumentsProps> = ({
|
||||
{loading ? (
|
||||
<SkeletonLoader count={1} />
|
||||
) : (
|
||||
<div className="flex flex-col">
|
||||
<div className="flex flex-col">
|
||||
<div className="flex-grow">
|
||||
<div className="dark:border-silver/40 border-silver rounded-md border overflow-auto">
|
||||
<table className="min-w-full divide-y divide-silver dark:divide-silver/40 text-xs sm:text-sm ">
|
||||
|
||||
293
frontend/src/settings/ToolConfig.tsx
Normal file
293
frontend/src/settings/ToolConfig.tsx
Normal file
@@ -0,0 +1,293 @@
|
||||
import React from 'react';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
import Input from '../components/Input';
|
||||
import { UserTool } from './types';
|
||||
|
||||
export default function ToolConfig({
|
||||
tool,
|
||||
setTool,
|
||||
handleGoBack,
|
||||
}: {
|
||||
tool: UserTool;
|
||||
setTool: (tool: UserTool) => void;
|
||||
handleGoBack: () => void;
|
||||
}) {
|
||||
const [authKey, setAuthKey] = React.useState<string>(
|
||||
tool.config?.token || '',
|
||||
);
|
||||
|
||||
const handleCheckboxChange = (actionIndex: number, property: string) => {
|
||||
setTool({
|
||||
...tool,
|
||||
actions: tool.actions.map((action, index) => {
|
||||
if (index === actionIndex) {
|
||||
return {
|
||||
...action,
|
||||
parameters: {
|
||||
...action.parameters,
|
||||
properties: {
|
||||
...action.parameters.properties,
|
||||
[property]: {
|
||||
...action.parameters.properties[property],
|
||||
filled_by_llm:
|
||||
!action.parameters.properties[property].filled_by_llm,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
return action;
|
||||
}),
|
||||
});
|
||||
};
|
||||
|
||||
const handleSaveChanges = () => {
|
||||
userService
|
||||
.updateTool({
|
||||
id: tool.id,
|
||||
name: tool.name,
|
||||
displayName: tool.displayName,
|
||||
description: tool.description,
|
||||
config: { token: authKey },
|
||||
actions: tool.actions,
|
||||
status: tool.status,
|
||||
})
|
||||
.then(() => {
|
||||
handleGoBack();
|
||||
});
|
||||
};
|
||||
return (
|
||||
<div className="mt-8 flex flex-col gap-4">
|
||||
<div className="mb-4 flex items-center gap-3 text-eerie-black dark:text-bright-gray text-sm">
|
||||
<button
|
||||
className="text-sm text-gray-400 dark:text-gray-500 border dark:border-0 dark:bg-[#28292D] dark:hover:bg-[#2E2F34] p-3 rounded-full"
|
||||
onClick={handleGoBack}
|
||||
>
|
||||
<img src={ArrowLeft} alt="left-arrow" className="w-3 h-3" />
|
||||
</button>
|
||||
<p className="mt-px">Back to all tools</p>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm font-semibold text-eerie-black dark:text-bright-gray">
|
||||
Type
|
||||
</p>
|
||||
<p className="mt-1 text-base font-normal text-eerie-black dark:text-bright-gray font-sans">
|
||||
{tool.name}
|
||||
</p>
|
||||
</div>
|
||||
<div className="mt-1">
|
||||
{Object.keys(tool?.config).length !== 0 && (
|
||||
<p className="text-sm font-semibold text-eerie-black dark:text-bright-gray">
|
||||
Authentication
|
||||
</p>
|
||||
)}
|
||||
<div className="mt-4 flex items-center gap-2">
|
||||
{Object.keys(tool?.config).length !== 0 && (
|
||||
<div className="relative w-96">
|
||||
<span className="absolute left-5 -top-2 bg-white px-2 text-xs text-gray-4000 dark:bg-[#26272E] dark:text-silver">
|
||||
API Key / Oauth
|
||||
</span>
|
||||
<Input
|
||||
type="text"
|
||||
value={authKey}
|
||||
onChange={(e) => setAuthKey(e.target.value)}
|
||||
borderVariant="thin"
|
||||
placeholder="Enter API Key / Oauth"
|
||||
></Input>
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
className="rounded-full h-10 w-36 bg-purple-30 text-white hover:bg-[#6F3FD1] text-nowrap text-sm"
|
||||
onClick={handleSaveChanges}
|
||||
>
|
||||
Save changes
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="mx-1 my-2 h-[0.8px] w-full rounded-full bg-[#C4C4C4]/40 lg:w-[95%] "></div>
|
||||
<p className="text-base font-semibold text-eerie-black dark:text-bright-gray">
|
||||
Actions
|
||||
</p>
|
||||
<div className="flex flex-col gap-10">
|
||||
{tool.actions.map((action, actionIndex) => {
|
||||
return (
|
||||
<div
|
||||
key={actionIndex}
|
||||
className="w-full border border-silver dark:border-silver/40 rounded-xl"
|
||||
>
|
||||
<div className="h-10 bg-[#F9F9F9] dark:bg-[#28292D] rounded-t-xl border-b border-silver dark:border-silver/40 flex items-center justify-between px-5">
|
||||
<p className="font-semibold text-eerie-black dark:text-bright-gray">
|
||||
{action.name}
|
||||
</p>
|
||||
<label
|
||||
htmlFor={`actionToggle-${actionIndex}`}
|
||||
className="relative inline-block h-6 w-10 cursor-pointer rounded-full bg-gray-300 dark:bg-[#D2D5DA33]/20 transition [-webkit-tap-highlight-color:_transparent] has-[:checked]:bg-[#0C9D35CC] has-[:checked]:dark:bg-[#0C9D35CC]"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
id={`actionToggle-${actionIndex}`}
|
||||
className="peer sr-only"
|
||||
checked={action.active}
|
||||
onChange={() => {
|
||||
setTool({
|
||||
...tool,
|
||||
actions: tool.actions.map((act, index) => {
|
||||
if (index === actionIndex) {
|
||||
return { ...act, active: !act.active };
|
||||
}
|
||||
return act;
|
||||
}),
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<span className="absolute inset-y-0 start-0 m-[3px] size-[18px] rounded-full bg-white transition-all peer-checked:start-4"></span>
|
||||
</label>
|
||||
</div>
|
||||
<div className="mt-5 relative px-5 w-96">
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter description"
|
||||
value={action.description}
|
||||
onChange={(e) => {
|
||||
setTool({
|
||||
...tool,
|
||||
actions: tool.actions.map((act, index) => {
|
||||
if (index === actionIndex) {
|
||||
return {
|
||||
...act,
|
||||
description: e.target.value,
|
||||
};
|
||||
}
|
||||
return act;
|
||||
}),
|
||||
});
|
||||
}}
|
||||
borderVariant="thin"
|
||||
></Input>
|
||||
</div>
|
||||
<div className="px-5 py-4">
|
||||
<table className="table-default">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Field Name</th>
|
||||
<th>Field Type</th>
|
||||
<th>Filled by LLM</th>
|
||||
<th>FIeld description</th>
|
||||
<th>Value</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{Object.entries(action.parameters?.properties).map(
|
||||
(param, index) => {
|
||||
const uniqueKey = `${actionIndex}-${param[0]}`;
|
||||
return (
|
||||
<tr key={index} className="text-nowrap font-normal">
|
||||
<td>{param[0]}</td>
|
||||
<td>{param[1].type}</td>
|
||||
<td>
|
||||
<label
|
||||
htmlFor={uniqueKey}
|
||||
className="ml-[10px] flex cursor-pointer items-start gap-4"
|
||||
>
|
||||
<div className="flex items-center">
|
||||
​
|
||||
<input
|
||||
checked={param[1].filled_by_llm}
|
||||
id={uniqueKey}
|
||||
type="checkbox"
|
||||
className="size-4 rounded border-gray-300 bg-transparent"
|
||||
onChange={() =>
|
||||
handleCheckboxChange(
|
||||
actionIndex,
|
||||
param[0],
|
||||
)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</label>
|
||||
</td>
|
||||
<td className="w-10">
|
||||
<input
|
||||
key={uniqueKey}
|
||||
value={param[1].description}
|
||||
className="bg-transparent border border-silver dark:border-silver/40 outline-none px-2 py-1 rounded-lg text-sm"
|
||||
onChange={(e) => {
|
||||
setTool({
|
||||
...tool,
|
||||
actions: tool.actions.map(
|
||||
(act, index) => {
|
||||
if (index === actionIndex) {
|
||||
return {
|
||||
...act,
|
||||
parameters: {
|
||||
...act.parameters,
|
||||
properties: {
|
||||
...act.parameters.properties,
|
||||
[param[0]]: {
|
||||
...act.parameters
|
||||
.properties[param[0]],
|
||||
description: e.target.value,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
return act;
|
||||
},
|
||||
),
|
||||
});
|
||||
}}
|
||||
></input>
|
||||
</td>
|
||||
<td>
|
||||
<input
|
||||
value={param[1].value}
|
||||
key={uniqueKey}
|
||||
disabled={param[1].filled_by_llm}
|
||||
className={`bg-transparent border border-silver dark:border-silver/40 outline-none px-2 py-1 rounded-lg text-sm ${param[1].filled_by_llm ? 'opacity-50' : ''}`}
|
||||
onChange={(e) => {
|
||||
setTool({
|
||||
...tool,
|
||||
actions: tool.actions.map(
|
||||
(act, index) => {
|
||||
if (index === actionIndex) {
|
||||
return {
|
||||
...act,
|
||||
parameters: {
|
||||
...act.parameters,
|
||||
properties: {
|
||||
...act.parameters.properties,
|
||||
[param[0]]: {
|
||||
...act.parameters
|
||||
.properties[param[0]],
|
||||
value: e.target.value,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
return act;
|
||||
},
|
||||
),
|
||||
});
|
||||
}}
|
||||
></input>
|
||||
</td>
|
||||
</tr>
|
||||
);
|
||||
},
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
157
frontend/src/settings/Tools.tsx
Normal file
157
frontend/src/settings/Tools.tsx
Normal file
@@ -0,0 +1,157 @@
|
||||
import React from 'react';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import CogwheelIcon from '../assets/cogwheel.svg';
|
||||
import Input from '../components/Input';
|
||||
import AddToolModal from '../modals/AddToolModal';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { UserTool } from './types';
|
||||
import ToolConfig from './ToolConfig';
|
||||
|
||||
export default function Tools() {
|
||||
const [searchTerm, setSearchTerm] = React.useState('');
|
||||
const [addToolModalState, setAddToolModalState] =
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [userTools, setUserTools] = React.useState<UserTool[]>([]);
|
||||
const [selectedTool, setSelectedTool] = React.useState<UserTool | null>(null);
|
||||
|
||||
const getUserTools = () => {
|
||||
userService
|
||||
.getUserTools()
|
||||
.then((res) => {
|
||||
return res.json();
|
||||
})
|
||||
.then((data) => {
|
||||
setUserTools(data.tools);
|
||||
});
|
||||
};
|
||||
|
||||
const updateToolStatus = (toolId: string, newStatus: boolean) => {
|
||||
userService
|
||||
.updateToolStatus({ id: toolId, status: newStatus })
|
||||
.then(() => {
|
||||
setUserTools((prevTools) =>
|
||||
prevTools.map((tool) =>
|
||||
tool.id === toolId ? { ...tool, status: newStatus } : tool,
|
||||
),
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to update tool status:', error);
|
||||
});
|
||||
};
|
||||
|
||||
const handleSettingsClick = (tool: UserTool) => {
|
||||
setSelectedTool(tool);
|
||||
};
|
||||
|
||||
const handleGoBack = () => {
|
||||
setSelectedTool(null);
|
||||
getUserTools();
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
getUserTools();
|
||||
}, []);
|
||||
return (
|
||||
<div>
|
||||
{selectedTool ? (
|
||||
<ToolConfig
|
||||
tool={selectedTool}
|
||||
setTool={setSelectedTool}
|
||||
handleGoBack={handleGoBack}
|
||||
/>
|
||||
) : (
|
||||
<div className="mt-8">
|
||||
<div className="flex flex-col relative">
|
||||
<div className="my-3 flex justify-between items-center gap-1">
|
||||
<div className="p-1">
|
||||
<Input
|
||||
maxLength={256}
|
||||
placeholder="Search..."
|
||||
name="Document-search-input"
|
||||
type="text"
|
||||
id="document-search-input"
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
className="rounded-full w-40 bg-purple-30 px-4 py-3 text-white hover:bg-[#6F3FD1] text-nowrap"
|
||||
onClick={() => {
|
||||
setAddToolModalState('ACTIVE');
|
||||
}}
|
||||
>
|
||||
Add Tool
|
||||
</button>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 lg:grid-cols-3 gap-6">
|
||||
{userTools
|
||||
.filter((tool) =>
|
||||
tool.displayName
|
||||
.toLowerCase()
|
||||
.includes(searchTerm.toLowerCase()),
|
||||
)
|
||||
.map((tool, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="relative h-56 w-full p-6 border rounded-2xl border-silver dark:border-silver/40 flex flex-col justify-between"
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="w-full flex items-center justify-between">
|
||||
<img
|
||||
src={`/toolIcons/tool_${tool.name}.svg`}
|
||||
className="h-8 w-8"
|
||||
/>
|
||||
<button
|
||||
className="absolute top-3 right-3 cursor-pointer"
|
||||
onClick={() => handleSettingsClick(tool)}
|
||||
>
|
||||
<img
|
||||
src={CogwheelIcon}
|
||||
alt="settings"
|
||||
className="h-[19px] w-[19px]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div className="mt-[9px]">
|
||||
<p className="text-sm font-semibold text-eerie-black dark:text-[#EEEEEE] leading-relaxed">
|
||||
{tool.displayName}
|
||||
</p>
|
||||
<p className="mt-1 h-16 overflow-auto text-[13px] text-gray-600 dark:text-gray-400 leading-relaxed pr-1">
|
||||
{tool.description}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="absolute bottom-3 right-3">
|
||||
<label
|
||||
htmlFor={`toolToggle-${index}`}
|
||||
className="relative inline-block h-6 w-10 cursor-pointer rounded-full bg-gray-300 dark:bg-[#D2D5DA33]/20 transition [-webkit-tap-highlight-color:_transparent] has-[:checked]:bg-[#0C9D35CC] has-[:checked]:dark:bg-[#0C9D35CC]"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
id={`toolToggle-${index}`}
|
||||
className="peer sr-only"
|
||||
checked={tool.status}
|
||||
onChange={() =>
|
||||
updateToolStatus(tool.id, !tool.status)
|
||||
}
|
||||
/>
|
||||
<span className="absolute inset-y-0 start-0 m-[3px] size-[18px] rounded-full bg-white transition-all peer-checked:start-4"></span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<AddToolModal
|
||||
message="Select a tool to set up"
|
||||
modalState={addToolModalState}
|
||||
setModalState={setAddToolModalState}
|
||||
getUserTools={getUserTools}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -7,8 +7,8 @@ import SettingsBar from '../components/SettingsBar';
|
||||
import i18n from '../locale/i18n';
|
||||
import { Doc } from '../models/misc';
|
||||
import {
|
||||
selectSourceDocs,
|
||||
selectPaginatedDocuments,
|
||||
selectSourceDocs,
|
||||
setPaginatedDocuments,
|
||||
setSourceDocs,
|
||||
} from '../preferences/preferenceSlice';
|
||||
@@ -17,6 +17,7 @@ import APIKeys from './APIKeys';
|
||||
import Documents from './Documents';
|
||||
import General from './General';
|
||||
import Logs from './Logs';
|
||||
import Tools from './Tools';
|
||||
import Widgets from './Widgets';
|
||||
|
||||
export default function Settings() {
|
||||
@@ -100,6 +101,8 @@ export default function Settings() {
|
||||
return <Analytics />;
|
||||
case t('settings.logs.label'):
|
||||
return <Logs />;
|
||||
case t('settings.tools.label'):
|
||||
return <Tools />;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -18,3 +18,32 @@ export type LogData = {
|
||||
retriever_params: Record<string, any>;
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
export type UserTool = {
|
||||
id: string;
|
||||
name: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
status: boolean;
|
||||
config: {
|
||||
[key: string]: string;
|
||||
};
|
||||
actions: {
|
||||
name: string;
|
||||
description: string;
|
||||
parameters: {
|
||||
properties: {
|
||||
[key: string]: {
|
||||
type: string;
|
||||
description: string;
|
||||
filled_by_llm: boolean;
|
||||
value: string;
|
||||
};
|
||||
};
|
||||
additionalProperties: boolean;
|
||||
required: string[];
|
||||
type: string;
|
||||
};
|
||||
active: boolean;
|
||||
}[];
|
||||
};
|
||||
|
||||
5
mock-backend/.gitignore
vendored
5
mock-backend/.gitignore
vendored
@@ -1,5 +0,0 @@
|
||||
|
||||
# Elastic Beanstalk Files
|
||||
.elasticbeanstalk/*
|
||||
!.elasticbeanstalk/*.cfg.yml
|
||||
!.elasticbeanstalk/*.global.yml
|
||||
@@ -1,11 +0,0 @@
|
||||
FROM node:20.6.1-bullseye-slim
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
COPY package*.json ./
|
||||
RUN npm install
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD [ "npm", "run", "start"]
|
||||
1379
mock-backend/package-lock.json
generated
1379
mock-backend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"name": "mock-backend",
|
||||
"version": "1.0.0",
|
||||
"description": "",
|
||||
"main": "index.js",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test": "echo \"Error: no test specified\" && exit 1",
|
||||
"start": "node src/server.js"
|
||||
},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"cors": "^2.8.5",
|
||||
"json-server": "^0.17.4",
|
||||
"uuid": "^9.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/json-server": "^0.14.5"
|
||||
}
|
||||
}
|
||||
@@ -1,244 +0,0 @@
|
||||
{
|
||||
"combine": [
|
||||
{
|
||||
"date": "default",
|
||||
"description": "default",
|
||||
"docLink": "default",
|
||||
"fullName": "default",
|
||||
"language": "default",
|
||||
"location": "local",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "default",
|
||||
"version": ""
|
||||
},
|
||||
{
|
||||
"date": "13/02/2023",
|
||||
"description": "Serverless Framework, the serverless application framework for building web, mobile and IoT applications on AWS Lambda, Azure Functions, Google CloudFunctions & more!",
|
||||
"docLink": "https://serverless.com/framework/docs/",
|
||||
"fullName": "Serverless Framework",
|
||||
"language": "serverless",
|
||||
"location": "remote",
|
||||
"name": "serverless framework",
|
||||
"version": "3.27.0"
|
||||
},
|
||||
{
|
||||
"date": "15/02/2023",
|
||||
"description": "Machine Learning in Python",
|
||||
"docLink": "https://scikit-learn.org/stable/",
|
||||
"fullName": "scikit-learn",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "scikit-learn",
|
||||
"version": "1.2.1"
|
||||
},
|
||||
{
|
||||
"date": "07/02/2023",
|
||||
"description": "Machine Learning in Python",
|
||||
"docLink": "https://scikit-learn.org/stable/",
|
||||
"fullName": "scikit-learn",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"name": "scikit-learn",
|
||||
"version": "1.2.1"
|
||||
},
|
||||
{
|
||||
"date": "07/02/2023",
|
||||
"description": "Pandas is alibrary providing high-performance, easy-to-use data structures and data analysis tools for the Python programming language.",
|
||||
"docLink": "https://pandas.pydata.org/docs/",
|
||||
"fullName": "Pandas",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "pandas",
|
||||
"version": "1.5.3"
|
||||
},
|
||||
{
|
||||
"date": "07/02/2023",
|
||||
"description": "Pandas is alibrary providing high-performance, easy-to-use data structures and data analysis tools for the Python programming language.",
|
||||
"docLink": "https://pandas.pydata.org/docs/",
|
||||
"fullName": "Pandas",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"name": "pandas",
|
||||
"version": "1.5.3"
|
||||
},
|
||||
{
|
||||
"date": "29/02/2023",
|
||||
"description": "Python is a programming language that lets you work quickly and integrate systems more effectively.",
|
||||
"docLink": "https://docs.python.org/3/",
|
||||
"fullName": "Python",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"model": "huggingface_sentence-transformers-all-mpnet-base-v2",
|
||||
"name": "python",
|
||||
"version": "3.11.1"
|
||||
},
|
||||
{
|
||||
"date": "15/02/2023",
|
||||
"description": "Python is a programming language that lets you work quickly and integrate systems more effectively.",
|
||||
"docLink": "https://docs.python.org/3/",
|
||||
"fullName": "Python",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "python",
|
||||
"version": "3.11.1"
|
||||
},
|
||||
{
|
||||
"date": "07/02/2023",
|
||||
"description": "Python is a programming language that lets you work quickly and integrate systems more effectively.",
|
||||
"docLink": "https://docs.python.org/3/",
|
||||
"fullName": "Python",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"name": "python",
|
||||
"version": "3.11.1"
|
||||
},
|
||||
{
|
||||
"date": "08/02/2023",
|
||||
"description": "GPT Index is a project consisting of a set of data structures designed to make it easier to use large external knowledge bases with LLMs.",
|
||||
"docLink": "https://gpt-index.readthedocs.io/en/latest/index.html",
|
||||
"fullName": "LangChain",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"name": "gpt-index",
|
||||
"version": "0.4.0"
|
||||
},
|
||||
{
|
||||
"date": "15/02/2023",
|
||||
"description": "Large language models (LLMs) are emerging as a transformative technology, enabling developers to build applications that they previously could not.",
|
||||
"docLink": "https://langchain.readthedocs.io/en/latest/index.html",
|
||||
"fullName": "LangChain",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "langchain",
|
||||
"version": "0.0.87"
|
||||
},
|
||||
{
|
||||
"date": "07/02/2023",
|
||||
"description": "Large language models (LLMs) are emerging as a transformative technology, enabling developers to build applications that they previously could not.",
|
||||
"docLink": "https://langchain.readthedocs.io/en/latest/index.html",
|
||||
"fullName": "LangChain",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"name": "langchain",
|
||||
"version": "0.0.79"
|
||||
},
|
||||
{
|
||||
"date": "13/03/2023",
|
||||
"description": "Large language models (LLMs) are emerging as a transformative technology, enabling developers to build applications that they previously could not.",
|
||||
"docLink": "https://langchain.readthedocs.io/en/latest/index.html",
|
||||
"fullName": "LangChain",
|
||||
"language": "python",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "langchain",
|
||||
"version": "0.0.109"
|
||||
},
|
||||
{
|
||||
"date": "16/03/2023",
|
||||
"description": "A JavaScript library for building user interfaces\nGet Started\n",
|
||||
"docLink": "https://reactjs.org/",
|
||||
"fullName": "React",
|
||||
"language": "javascript",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "react",
|
||||
"version": "v18.2.0"
|
||||
},
|
||||
{
|
||||
"date": "15/02/2023",
|
||||
"description": "is a lightweight, interpreted, or just-in-time compiled programming language with first-class functions.",
|
||||
"docLink": "https://developer.mozilla.org/en-US/docs/Web/JavaScript",
|
||||
"fullName": "JavaScript",
|
||||
"language": "javascript",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "javascript",
|
||||
"version": "ES2015"
|
||||
},
|
||||
{
|
||||
"date": "16/03/2023",
|
||||
"description": "An approachable, performant and versatile framework for building web user interfaces. ",
|
||||
"docLink": "https://vuejs.org/",
|
||||
"fullName": "Vue.js",
|
||||
"language": "javascript",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "vuejs",
|
||||
"version": "v3.3.0"
|
||||
},
|
||||
{
|
||||
"date": "16/03/2023",
|
||||
"description": "Get ready for a development environment that can finally catch up with you.",
|
||||
"docLink": "https://vitejs.dev/",
|
||||
"fullName": "Vite",
|
||||
"language": "javascript",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "vitejs",
|
||||
"version": "v4.2.0"
|
||||
},
|
||||
{
|
||||
"date": "15/02/2023",
|
||||
"description": "Solidity is an object-oriented, high-level language for implementing smart contracts.",
|
||||
"docLink": "https://docs.soliditylang.org/en/v0.8.18/",
|
||||
"fullName": "Solidity",
|
||||
"language": "ethereum",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "solidity",
|
||||
"version": "0.8.18"
|
||||
},
|
||||
{
|
||||
"date": "07/02/2023",
|
||||
"description": "Solidity is an object-oriented, high-level language for implementing smart contracts.",
|
||||
"docLink": "https://docs.soliditylang.org/en/v0.8.18/",
|
||||
"fullName": "Solidity",
|
||||
"language": "ethereum",
|
||||
"location": "remote",
|
||||
"name": "solidity",
|
||||
"version": "0.8.18"
|
||||
},
|
||||
{
|
||||
"date": "28/02/2023",
|
||||
"description": "GPT-powered chat for documentation search & assistance. ",
|
||||
"docLink": "https://github.com/arc53/DocsGPT/wiki",
|
||||
"fullName": "DocsGPT",
|
||||
"language": "docsgpt",
|
||||
"location": "remote",
|
||||
"model": "huggingface_sentence-transformers-all-mpnet-base-v2",
|
||||
"name": "docsgpt",
|
||||
"version": "0.1.0"
|
||||
},
|
||||
{
|
||||
"date": "28/02/2023",
|
||||
"description": "GPT-powered chat for documentation search & assistance. ",
|
||||
"docLink": "https://github.com/arc53/DocsGPT/wiki",
|
||||
"fullName": "DocsGPT",
|
||||
"language": "docsgpt",
|
||||
"location": "remote",
|
||||
"model": "openai_text-embedding-ada-002",
|
||||
"name": "docsgpt",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
],
|
||||
"conversations": [
|
||||
{
|
||||
"id": "65cf39c936523eea21ebe117",
|
||||
"name": "Request clarification"
|
||||
},
|
||||
{
|
||||
"id": "65cf39ba36523eea21ebe116",
|
||||
"name": "Clarification request"
|
||||
},
|
||||
{
|
||||
"id": "65cf37e97d527c332bbac933",
|
||||
"name": "Greetings, assistance inquiry."
|
||||
}],
|
||||
"docs_check": {
|
||||
"status": "loaded"
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
{
|
||||
"/api/*": "/$1",
|
||||
"/get_conversations": "/conversations",
|
||||
"/get_single_conversation?id=:id": "/conversations/:id",
|
||||
"/delete_conversation?id=:id": "/conversations/:id",
|
||||
"/conversations?id=:id": "/conversations/:id"
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
import jsonServer from "json-server";
|
||||
import routes from "./mocks/routes.json" assert { type: "json" };
|
||||
import { v4 as uuid } from "uuid";
|
||||
import cors from 'cors'
|
||||
const server = jsonServer.create();
|
||||
const router = jsonServer.router("./src/mocks/db.json");
|
||||
const middlewares = jsonServer.defaults();
|
||||
|
||||
const localStorage = [];
|
||||
|
||||
server.use(middlewares);
|
||||
server.use(cors({ origin: ['*'] }))
|
||||
server.use(jsonServer.rewriter(routes));
|
||||
|
||||
server.use((req, res, next) => {
|
||||
if (req.method === "POST") {
|
||||
if (req.url.includes("/delete_conversation")) {
|
||||
req.method = "DELETE";
|
||||
} else if (req.url === "/upload") {
|
||||
const taskId = uuid();
|
||||
localStorage.push(taskId);
|
||||
}
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
router.render = (req, res) => {
|
||||
if (req.url === "/feedback") {
|
||||
res.status(200).jsonp({ status: "ok" });
|
||||
} else if (req.url === "/upload") {
|
||||
res.status(200).jsonp({
|
||||
status: "ok",
|
||||
task_id: localStorage[localStorage.length - 1],
|
||||
});
|
||||
} else if (req.url.includes("/task_status")) {
|
||||
const taskId = req.query["task_id"];
|
||||
const taskIdExists = localStorage.includes(taskId);
|
||||
if (taskIdExists) {
|
||||
res.status(200).jsonp({
|
||||
result: {
|
||||
directory: "temp",
|
||||
filename: "install.rst",
|
||||
formats: [".rst", ".md", ".pdf"],
|
||||
name_job: "somename",
|
||||
user: "local",
|
||||
},
|
||||
status: "SUCCESS",
|
||||
});
|
||||
} else {
|
||||
res.status(404).jsonp({});
|
||||
}
|
||||
} else if (req.url === "/stream" && req.method === "POST") {
|
||||
res.writeHead(200, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive'
|
||||
});
|
||||
const message = ('Hi, How are you today?').split(' ');
|
||||
let index = 0;
|
||||
const interval = setInterval(() => {
|
||||
if (index < message.length) {
|
||||
res.write(`data: {"answer": "${message[index++]} "}\n`);
|
||||
} else {
|
||||
res.write(`data: {"type": "id", "id": "65cbc39d11f077b9eeb06d26"}\n`)
|
||||
res.write(`data: {"type": "end"}\n`)
|
||||
clearInterval(interval); // Stop the interval once the message is fully streamed
|
||||
res.end(); // End the response
|
||||
}
|
||||
}, 500); // Send a word every 1 second
|
||||
}
|
||||
else if (req.url === '/search' && req.method === 'POST') {
|
||||
res.status(200).json(
|
||||
[
|
||||
{
|
||||
"text": "\n\n/api/answer\nIt's a POST request that sends a JSON in body with 4 values. It will receive an answer for a user provided question.\n",
|
||||
"title": "API-docs.md"
|
||||
},
|
||||
{
|
||||
"text": "\n\nOur Standards\n\nExamples of behavior that contribute to a positive environment for our\ncommunity include:\n* Demonstrating empathy and kindness towards other people\n",
|
||||
"title": "How-to-use-different-LLM.md"
|
||||
}
|
||||
]
|
||||
)
|
||||
}
|
||||
else if (req.url === '/get_prompts' && req.method === 'GET') {
|
||||
res.status(200).json([
|
||||
{
|
||||
"id": "default",
|
||||
"name": "default",
|
||||
"type": "public"
|
||||
},
|
||||
{
|
||||
"id": "creative",
|
||||
"name": "creative",
|
||||
"type": "public"
|
||||
},
|
||||
{
|
||||
"id": "strict",
|
||||
"name": "strict",
|
||||
"type": "public"
|
||||
}
|
||||
]);
|
||||
}
|
||||
else if (req.url.startsWith('/get_single_prompt') && req.method==='GET') {
|
||||
const id = req.query.id;
|
||||
console.log('hre');
|
||||
if (id === 'creative')
|
||||
res.status(200).json({
|
||||
"content": "You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible."
|
||||
})
|
||||
else if (id === 'strict') {
|
||||
res.status(200).json({
|
||||
"content": "You are an AI Assistant, DocsGPT, adept at offering document assistance. \nYour expertise lies in providing answer on top of provided context."
|
||||
})
|
||||
}
|
||||
else {
|
||||
res.status(200).json({
|
||||
"content": "You are a helpful AI assistant, DocsGPT, specializing in document assistance, designed to offer detailed and informative responses."
|
||||
})
|
||||
}
|
||||
}
|
||||
else {
|
||||
res.status(res.statusCode).jsonp(res.locals.data);
|
||||
}
|
||||
};
|
||||
|
||||
server.use(router);
|
||||
|
||||
server.listen(8080, () => {
|
||||
console.log("JSON Server is running");
|
||||
});
|
||||
@@ -1,22 +0,0 @@
|
||||
dataclasses_json==0.6.3
|
||||
docx2txt==0.8
|
||||
EbookLib==0.18
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
faiss_cpu==1.7.4
|
||||
html2text==2020.1.16
|
||||
javalang==0.13.0
|
||||
langchain==0.2.10
|
||||
langchain_community==0.2.9
|
||||
langchain-openai==0.0.5
|
||||
nltk==3.9
|
||||
openapi3_parser==1.1.16
|
||||
pandas==2.2.0
|
||||
PyPDF2==3.0.1
|
||||
python-dotenv==1.0.1
|
||||
retry==0.9.2
|
||||
Sphinx==7.2.6
|
||||
tiktoken==0.5.2
|
||||
tqdm==4.66.3
|
||||
typer==0.9.0
|
||||
unstructured==0.12.2
|
||||
@@ -46,6 +46,7 @@ class TestAnthropicLLM(unittest.TestCase):
|
||||
{"content": "question"}
|
||||
]
|
||||
mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
|
||||
mock_tools = Mock()
|
||||
|
||||
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
||||
mock_redis_instance = mock_make_redis.return_value
|
||||
@@ -53,7 +54,7 @@ class TestAnthropicLLM(unittest.TestCase):
|
||||
mock_redis_instance.set = Mock()
|
||||
|
||||
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
|
||||
responses = list(self.llm.gen_stream("test_model", messages))
|
||||
responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools))
|
||||
self.assertListEqual(responses, ["response_1", "response_2"])
|
||||
|
||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
|
||||
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = list(self.sagemaker.gen_stream(None, self.messages))
|
||||
output = list(self.sagemaker.gen_stream(None, self.messages, tools=None))
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
|
||||
@@ -12,18 +12,21 @@ def test_make_gen_cache_key():
|
||||
{'role': 'system', 'content': 'test_system_message'},
|
||||
]
|
||||
model = "test_docgpt"
|
||||
tools = None
|
||||
|
||||
# Manually calculate the expected hash
|
||||
expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
|
||||
messages_str = json.dumps(messages)
|
||||
tools_str = json.dumps(tools) if tools else ""
|
||||
expected_combined = f"{model}_{messages_str}_{tools_str}"
|
||||
expected_hash = get_hash(expected_combined)
|
||||
cache_key = gen_cache_key(*messages, model=model)
|
||||
cache_key = gen_cache_key(messages, model=model, tools=None)
|
||||
|
||||
assert cache_key == expected_hash
|
||||
|
||||
def test_gen_cache_key_invalid_message_format():
|
||||
# Test when messages is not a list
|
||||
with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
|
||||
gen_cache_key("This is not a list", model="docgpt")
|
||||
gen_cache_key("This is not a list", model="docgpt", tools=None)
|
||||
assert str(context.exception) == "All messages must be dictionaries."
|
||||
|
||||
# Test for gen_cache decorator
|
||||
@@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis):
|
||||
mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit
|
||||
|
||||
@gen_cache
|
||||
def mock_function(self, model, messages):
|
||||
def mock_function(self, model, messages, stream, tools):
|
||||
return "new_result"
|
||||
|
||||
messages = [{'role': 'user', 'content': 'test_user_message'}]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = mock_function(None, model, messages)
|
||||
result = mock_function(None, model, messages, stream=False, tools=None)
|
||||
|
||||
# Assert
|
||||
assert result == "cached_result" # Should return cached result
|
||||
@@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis):
|
||||
mock_redis_instance.get.return_value = None # Simulate a cache miss
|
||||
|
||||
@gen_cache
|
||||
def mock_function(self, model, messages):
|
||||
def mock_function(self, model, messages, steam, tools):
|
||||
return "new_result"
|
||||
|
||||
messages = [
|
||||
@@ -67,7 +70,7 @@ def test_gen_cache_miss(mock_make_redis):
|
||||
]
|
||||
model = "test_docgpt"
|
||||
# Act
|
||||
result = mock_function(None, model, messages)
|
||||
result = mock_function(None, model, messages, stream=False, tools=None)
|
||||
|
||||
# Assert
|
||||
assert result == "new_result"
|
||||
@@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis):
|
||||
mock_redis_instance.get.return_value = cached_chunk
|
||||
|
||||
@stream_cache
|
||||
def mock_function(self, model, messages, stream):
|
||||
def mock_function(self, model, messages, stream, tools):
|
||||
yield "new_chunk"
|
||||
|
||||
messages = [{'role': 'user', 'content': 'test_user_message'}]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = list(mock_function(None, model, messages, stream=True))
|
||||
result = list(mock_function(None, model, messages, stream=True, tools=None))
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2"] # Should return cached chunks
|
||||
@@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis):
|
||||
mock_redis_instance.get.return_value = None # Simulate a cache miss
|
||||
|
||||
@stream_cache
|
||||
def mock_function(self, model, messages, stream):
|
||||
def mock_function(self, model, messages, stream, tools):
|
||||
yield "new_chunk"
|
||||
|
||||
messages = [
|
||||
@@ -117,7 +120,7 @@ def test_stream_cache_miss(mock_make_redis):
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = list(mock_function(None, model, messages, stream=True))
|
||||
result = list(mock_function(None, model, messages, stream=True, tools=None))
|
||||
|
||||
# Assert
|
||||
assert result == ["new_chunk"]
|
||||
|
||||
Reference in New Issue
Block a user