mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'main' of https://github.com/manishmadan2882/docsgpt
This commit is contained in:
@@ -9,10 +9,21 @@ from application.llm.llm_creator import LLMCreator
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
def __init__(
|
||||
self,
|
||||
endpoint,
|
||||
llm_name,
|
||||
gpt_model,
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.gpt_model = gpt_model
|
||||
|
||||
@@ -17,8 +17,12 @@ class ClassicAgent(BaseAgent):
|
||||
user_api_key=None,
|
||||
prompt="",
|
||||
chat_history=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
|
||||
super().__init__(
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
|
||||
)
|
||||
self.user = decoded_token.get("sub")
|
||||
self.prompt = prompt
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
|
||||
@@ -73,7 +77,7 @@ class ClassicAgent(BaseAgent):
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
|
||||
tools_dict = self._get_user_tools()
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self._llm_gen(messages_combine, log_context)
|
||||
|
||||
@@ -124,6 +124,7 @@ def save_conversation(
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index=None,
|
||||
api_key=None,
|
||||
):
|
||||
@@ -182,7 +183,7 @@ def save_conversation(
|
||||
|
||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
||||
conversation_data = {
|
||||
"user": "local",
|
||||
"user": decoded_token.get("sub"),
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [
|
||||
@@ -223,6 +224,7 @@ def complete_stream(
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
should_save_conversation=True,
|
||||
@@ -262,7 +264,10 @@ def complete_stream(
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -273,6 +278,7 @@ def complete_stream(
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
@@ -288,7 +294,7 @@ def complete_stream(
|
||||
{
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
@@ -383,15 +389,21 @@ class Stream(Resource):
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
@@ -411,6 +423,7 @@ class Stream(Resource):
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
@@ -422,6 +435,7 @@ class Stream(Resource):
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
return Response(
|
||||
@@ -431,6 +445,7 @@ class Stream(Resource):
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
should_save_conversation=save_conv,
|
||||
@@ -523,13 +538,21 @@ class Answer(Resource):
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
@@ -547,6 +570,7 @@ class Answer(Resource):
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
@@ -558,6 +582,7 @@ class Answer(Resource):
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
response_full = ""
|
||||
@@ -571,6 +596,7 @@ class Answer(Resource):
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=False,
|
||||
@@ -604,7 +630,10 @@ class Answer(Resource):
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
@@ -616,6 +645,7 @@ class Answer(Resource):
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
)
|
||||
@@ -625,7 +655,7 @@ class Answer(Resource):
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
@@ -694,12 +724,20 @@ class Search(Resource):
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
@@ -715,6 +753,7 @@ class Search(Resource):
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
docs = retriever.search(question)
|
||||
@@ -724,7 +763,7 @@ class Search(Resource):
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
|
||||
@@ -15,7 +15,6 @@ from werkzeug.utils import secure_filename
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
from application.api.user.tasks import ingest, ingest_remote
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.extensions import api
|
||||
@@ -68,6 +67,21 @@ def generate_date_range(start_date, end_date):
|
||||
}
|
||||
|
||||
|
||||
def get_vector_store(source_id):
|
||||
"""
|
||||
Get the Vector Store
|
||||
Args:
|
||||
source_id (str): source id of the document
|
||||
"""
|
||||
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
@user_ns.route("/api/delete_conversation")
|
||||
class DeleteConversation(Resource):
|
||||
@api.doc(
|
||||
@@ -75,6 +89,9 @@ class DeleteConversation(Resource):
|
||||
params={"id": "The ID of the conversation to delete"},
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
@@ -82,7 +99,9 @@ class DeleteConversation(Resource):
|
||||
)
|
||||
|
||||
try:
|
||||
conversations_collection.delete_one({"_id": ObjectId(conversation_id)})
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting conversation: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -95,7 +114,10 @@ class DeleteAllConversations(Resource):
|
||||
description="Deletes all conversations for a specific user",
|
||||
)
|
||||
def get(self):
|
||||
user_id = "local"
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
except Exception as err:
|
||||
@@ -110,11 +132,18 @@ class GetConversations(Resource):
|
||||
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
conversations = conversations_collection.find(
|
||||
{"api_key": {"$exists": False}}
|
||||
).sort("date", -1).limit(30)
|
||||
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{"api_key": {"$exists": False}, "user": decoded_token.get("sub")}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{"id": str(conversation["_id"]), "name": conversation["name"]}
|
||||
for conversation in conversations
|
||||
@@ -132,6 +161,9 @@ class GetSingleConversation(Resource):
|
||||
params={"id": "The conversation ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
@@ -140,7 +172,7 @@ class GetSingleConversation(Resource):
|
||||
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
@@ -167,6 +199,9 @@ class UpdateConversationName(Resource):
|
||||
description="Updates the name of a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -175,7 +210,8 @@ class UpdateConversationName(Resource):
|
||||
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])}, {"$set": {"name": data["name"]}}
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating conversation name: {err}")
|
||||
@@ -210,6 +246,9 @@ class SubmitFeedback(Resource):
|
||||
description="Submit feedback for a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["feedback", "conversation_id", "question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -222,12 +261,13 @@ class SubmitFeedback(Resource):
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": ""
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -236,12 +276,17 @@ class SubmitFeedback(Resource):
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data["feedback"],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(datetime.timezone.utc)
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -284,13 +329,18 @@ class DeleteOldIndexes(Resource):
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
|
||||
doc = sources_collection.find_one({"_id": ObjectId(source_id), "user": "local"})
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
try:
|
||||
@@ -328,6 +378,9 @@ class UploadFile(Resource):
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
@@ -343,7 +396,7 @@ class UploadFile(Resource):
|
||||
400,
|
||||
)
|
||||
|
||||
user = secure_filename(request.form["user"])
|
||||
user = secure_filename(decoded_token.get("sub"))
|
||||
job_name = secure_filename(request.form["name"])
|
||||
try:
|
||||
save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
|
||||
@@ -443,6 +496,9 @@ class UploadRemote(Resource):
|
||||
description="Uploads remote source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -463,7 +519,7 @@ class UploadRemote(Resource):
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=data["user"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -519,7 +575,10 @@ class RedirectToSources(Resource):
|
||||
class PaginatedSources(Resource):
|
||||
@api.doc(description="Get document with pagination, sorting and filtering")
|
||||
def get(self):
|
||||
user = "local"
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
@@ -584,7 +643,10 @@ class PaginatedSources(Resource):
|
||||
class CombinedJson(Resource):
|
||||
@api.doc(description="Provide JSON file with combined available indexes")
|
||||
def get(self):
|
||||
user = "local"
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = [
|
||||
{
|
||||
"id": "default",
|
||||
@@ -688,13 +750,16 @@ class CreatePrompt(Resource):
|
||||
@api.expect(create_prompt_model)
|
||||
@api.doc(description="Create a new prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["content", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
user = "local"
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
@@ -716,7 +781,10 @@ class CreatePrompt(Resource):
|
||||
class GetPrompts(Resource):
|
||||
@api.doc(description="Get all prompts for the user")
|
||||
def get(self):
|
||||
user = "local"
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
list_prompts = [
|
||||
@@ -744,6 +812,10 @@ class GetPrompts(Resource):
|
||||
class GetSinglePrompt(Resource):
|
||||
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
prompt_id = request.args.get("id")
|
||||
if not prompt_id:
|
||||
return make_response(
|
||||
@@ -774,7 +846,9 @@ class GetSinglePrompt(Resource):
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -792,6 +866,10 @@ class DeletePrompt(Resource):
|
||||
@api.expect(delete_prompt_model)
|
||||
@api.doc(description="Delete a prompt by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -799,7 +877,7 @@ class DeletePrompt(Resource):
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"])})
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -823,6 +901,10 @@ class UpdatePrompt(Resource):
|
||||
@api.expect(update_prompt_model)
|
||||
@api.doc(description="Update an existing prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name", "content"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -831,7 +913,7 @@ class UpdatePrompt(Resource):
|
||||
|
||||
try:
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -845,7 +927,10 @@ class UpdatePrompt(Resource):
|
||||
class GetApiKeys(Resource):
|
||||
@api.doc(description="Retrieve API keys for the user")
|
||||
def get(self):
|
||||
user = "local"
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
keys = api_key_collection.find({"user": user})
|
||||
list_keys = []
|
||||
@@ -892,13 +977,16 @@ class CreateApiKey(Resource):
|
||||
@api.expect(create_api_key_model)
|
||||
@api.doc(description="Create a new API key")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["name", "prompt_id", "chunks"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
user = "local"
|
||||
try:
|
||||
key = str(uuid.uuid4())
|
||||
new_api_key = {
|
||||
@@ -932,6 +1020,10 @@ class DeleteApiKey(Resource):
|
||||
@api.expect(delete_api_key_model)
|
||||
@api.doc(description="Delete an API key by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -939,7 +1031,9 @@ class DeleteApiKey(Resource):
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
result = api_key_collection.delete_one({"_id": ObjectId(data["id"])})
|
||||
result = api_key_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return {"success": False, "message": "API Key not found"}, 404
|
||||
except Exception as err:
|
||||
@@ -966,6 +1060,10 @@ class ShareConversation(Resource):
|
||||
@api.expect(share_conversation_model)
|
||||
@api.doc(description="Share a conversation")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -977,8 +1075,6 @@ class ShareConversation(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "isPromptable is required"}), 400
|
||||
)
|
||||
|
||||
user = data.get("user", "local")
|
||||
conversation_id = data["conversation_id"]
|
||||
|
||||
try:
|
||||
@@ -1214,7 +1310,13 @@ class GetMessageAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=["last_hour", "last_24_hour", "last_7_days", "last_15_days", "last_30_days"],
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -1222,13 +1324,19 @@ class GetMessageAnalytics(Resource):
|
||||
@api.expect(get_message_analytics_model)
|
||||
@api.doc(description="Get message analytics based on filter option")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
api_key_collection.find_one(
|
||||
{"_id": ObjectId(api_key_id), "user": user}
|
||||
)["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
@@ -1247,9 +1355,9 @@ class GetMessageAnalytics(Resource):
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6 if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days"
|
||||
else 29
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
@@ -1257,41 +1365,40 @@ class GetMessageAnalytics(Resource):
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
|
||||
pipeline = [
|
||||
# Initial match for API key if provided
|
||||
{
|
||||
"$match": {
|
||||
"api_key": api_key if api_key else {"$exists": False}
|
||||
}
|
||||
},
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
# Match queries within the time range
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date
|
||||
}
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
# Group by formatted timestamp
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp"
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1}
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
# Sort by timestamp
|
||||
{"$sort": {"_id": 1}}
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
@@ -1341,13 +1448,19 @@ class GetTokenAnalytics(Resource):
|
||||
@api.expect(get_token_analytics_model)
|
||||
@api.doc(description="Get token analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
api_key_collection.find_one(
|
||||
{"_id": ObjectId(api_key_id), "user": user}
|
||||
)["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
@@ -1429,13 +1542,12 @@ class GetTokenAnalytics(Resource):
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
else:
|
||||
match_stage["$match"]["api_key"] = {"$exists": False}
|
||||
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
@@ -1495,13 +1607,19 @@ class GetFeedbackAnalytics(Resource):
|
||||
@api.expect(get_feedback_analytics_model)
|
||||
@api.doc(description="Get feedback analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
api_key_collection.find_one(
|
||||
{"_id": ObjectId(api_key_id), "user": user}
|
||||
)["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
@@ -1514,11 +1632,21 @@ class GetFeedbackAnalytics(Resource):
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
@@ -1536,21 +1664,26 @@ class GetFeedbackAnalytics(Resource):
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"queries.feedback": {"$exists": True}
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
else:
|
||||
match_stage["$match"]["api_key"] = {"$exists": False}
|
||||
|
||||
# Unwind the queries array to process each query separately
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
@@ -1637,6 +1770,10 @@ class GetUserLogs(Resource):
|
||||
@api.expect(get_user_logs_model)
|
||||
@api.doc(description="Get user logs with pagination")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
@@ -1653,7 +1790,7 @@ class GetUserLogs(Resource):
|
||||
current_app.logger.error(f"Error getting API key: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
query = {}
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
|
||||
@@ -1711,6 +1848,10 @@ class ManageSync(Resource):
|
||||
@api.expect(manage_sync_model)
|
||||
@api.doc(description="Manage sync frequency for sources")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["source_id", "sync_frequency"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -1730,7 +1871,7 @@ class ManageSync(Resource):
|
||||
sources_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(source_id),
|
||||
"user": "local",
|
||||
"user": user,
|
||||
},
|
||||
update_data,
|
||||
)
|
||||
@@ -1807,7 +1948,10 @@ class GetTools(Resource):
|
||||
@api.doc(description="Get tools created by a user")
|
||||
def get(self):
|
||||
try:
|
||||
user = "local"
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
user_tools = []
|
||||
for tool in tools:
|
||||
@@ -1850,6 +1994,10 @@ class CreateTool(Resource):
|
||||
)
|
||||
@api.doc(description="Create a new tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = [
|
||||
"name",
|
||||
@@ -1863,7 +2011,6 @@ class CreateTool(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
user = "local"
|
||||
transformed_actions = []
|
||||
for action in data["actions"]:
|
||||
action["active"] = True
|
||||
@@ -1914,6 +2061,10 @@ class UpdateTool(Resource):
|
||||
)
|
||||
@api.doc(description="Update a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -1949,7 +2100,7 @@ class UpdateTool(Resource):
|
||||
update_data["status"] = data["status"]
|
||||
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": "local"},
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": update_data},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -1974,6 +2125,10 @@ class UpdateToolConfig(Resource):
|
||||
)
|
||||
@api.doc(description="Update the configuration of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -1982,7 +2137,7 @@ class UpdateToolConfig(Resource):
|
||||
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"config": data["config"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -2009,6 +2164,10 @@ class UpdateToolActions(Resource):
|
||||
)
|
||||
@api.doc(description="Update the actions of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "actions"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -2017,7 +2176,7 @@ class UpdateToolActions(Resource):
|
||||
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"actions": data["actions"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -2042,6 +2201,10 @@ class UpdateToolStatus(Resource):
|
||||
)
|
||||
@api.doc(description="Update the status of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "status"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -2050,7 +2213,7 @@ class UpdateToolStatus(Resource):
|
||||
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"status": data["status"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -2070,6 +2233,10 @@ class DeleteTool(Resource):
|
||||
)
|
||||
@api.doc(description="Delete a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -2077,7 +2244,9 @@ class DeleteTool(Resource):
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])})
|
||||
result = user_tools_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return {"success": False, "message": "Tool not found"}, 404
|
||||
except Exception as err:
|
||||
@@ -2087,21 +2256,6 @@ class DeleteTool(Resource):
|
||||
return {"success": True}, 200
|
||||
|
||||
|
||||
def get_vector_store(source_id):
|
||||
"""
|
||||
Get the Vector Store
|
||||
Args:
|
||||
source_id (str): source id of the document
|
||||
"""
|
||||
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
@user_ns.route("/api/get_chunks")
|
||||
class GetChunks(Resource):
|
||||
@api.doc(
|
||||
@@ -2109,6 +2263,10 @@ class GetChunks(Resource):
|
||||
params={"id": "The document ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
page = int(request.args.get("page", 1))
|
||||
per_page = int(request.args.get("per_page", 10))
|
||||
@@ -2116,6 +2274,12 @@ class GetChunks(Resource):
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunks = store.get_chunks()
|
||||
@@ -2160,6 +2324,10 @@ class AddChunk(Resource):
|
||||
description="Adds a new chunk to the document",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "text"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -2173,6 +2341,12 @@ class AddChunk(Resource):
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunk_id = store.add_chunk(text, metadata)
|
||||
@@ -2192,12 +2366,22 @@ class DeleteChunk(Resource):
|
||||
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
chunk_id = request.args.get("chunk_id")
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
@@ -2239,6 +2423,10 @@ class UpdateChunk(Resource):
|
||||
description="Updates an existing chunk in the document.",
|
||||
)
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "chunk_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
@@ -2253,6 +2441,12 @@ class UpdateChunk(Resource):
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunks = store.get_chunks()
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import os
|
||||
import platform
|
||||
import uuid
|
||||
|
||||
import dotenv
|
||||
from flask import Flask, redirect, request
|
||||
from flask import Flask, jsonify, redirect, request
|
||||
from jose import jwt
|
||||
|
||||
from application.auth import handle_auth
|
||||
|
||||
from application.core.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
from application.api.answer.routes import answer # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.extensions import api # noqa: E402
|
||||
from application.api.answer.routes import answer # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.extensions import api # noqa: E402
|
||||
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import pathlib
|
||||
|
||||
pathlib.PosixPath = pathlib.WindowsPath
|
||||
|
||||
dotenv.load_dotenv()
|
||||
@@ -32,6 +40,25 @@ app.config.update(
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
api.init_app(app)
|
||||
|
||||
if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECRET_KEY:
|
||||
key_file = ".jwt_secret_key"
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
settings.JWT_SECRET_KEY = f.read().strip()
|
||||
except FileNotFoundError:
|
||||
new_key = os.urandom(32).hex()
|
||||
with open(key_file, "w") as f:
|
||||
f.write(new_key)
|
||||
settings.JWT_SECRET_KEY = new_key
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
||||
|
||||
SIMPLE_JWT_TOKEN = None
|
||||
if settings.AUTH_TYPE == "simple_jwt":
|
||||
payload = {"sub": "local"}
|
||||
SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256")
|
||||
print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}")
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
@@ -41,11 +68,47 @@ def home():
|
||||
return "Welcome to DocsGPT Backend!"
|
||||
|
||||
|
||||
@app.route("/api/config")
|
||||
def get_config():
|
||||
response = {
|
||||
"auth_type": settings.AUTH_TYPE,
|
||||
"requires_auth": settings.AUTH_TYPE in ["simple_jwt", "session_jwt"],
|
||||
}
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@app.route("/api/generate_token")
|
||||
def generate_token():
|
||||
if settings.AUTH_TYPE == "session_jwt":
|
||||
new_user_id = str(uuid.uuid4())
|
||||
token = jwt.encode(
|
||||
{"sub": new_user_id}, settings.JWT_SECRET_KEY, algorithm="HS256"
|
||||
)
|
||||
return jsonify({"token": token})
|
||||
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
|
||||
|
||||
|
||||
@app.before_request
|
||||
def authenticate_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
|
||||
decoded_token = handle_auth(request)
|
||||
if not decoded_token:
|
||||
request.decoded_token = None
|
||||
elif "error" in decoded_token:
|
||||
return jsonify(decoded_token), 401
|
||||
else:
|
||||
request.decoded_token = decoded_token
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
response.headers.add("Access-Control-Allow-Origin", "*")
|
||||
response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
|
||||
response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
|
||||
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
response.headers.add(
|
||||
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
|
||||
28
application/auth.py
Normal file
28
application/auth.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from jose import jwt
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def handle_auth(request, data={}):
|
||||
if settings.AUTH_TYPE in ["simple_jwt", "session_jwt"]:
|
||||
jwt_token = request.headers.get("Authorization")
|
||||
if not jwt_token:
|
||||
return None
|
||||
|
||||
jwt_token = jwt_token.replace("Bearer ", "")
|
||||
|
||||
try:
|
||||
decoded_token = jwt.decode(
|
||||
jwt_token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=["HS256"],
|
||||
options={"verify_exp": False},
|
||||
)
|
||||
return decoded_token
|
||||
except Exception as e:
|
||||
return {
|
||||
"message": f"Authentication error: {str(e)}",
|
||||
"error": "invalid_token",
|
||||
}
|
||||
else:
|
||||
return {"sub": "local"}
|
||||
@@ -10,6 +10,7 @@ current_dir = os.path.dirname(
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
AUTH_TYPE: Optional[str] = None
|
||||
LLM_NAME: str = "docsgpt"
|
||||
MODEL_NAME: Optional[str] = (
|
||||
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
@@ -98,6 +99,8 @@ class Settings(BaseSettings):
|
||||
|
||||
FLASK_DEBUG_MODE: bool = False
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -5,7 +5,8 @@ from application.usage import gen_token_usage, stream_token_usage
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
def __init__(self):
|
||||
def __init__(self, decoded_token=None):
|
||||
self.decoded_token = decoded_token
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
def _apply_decorator(self, method, decorators, *args, **kwargs):
|
||||
|
||||
@@ -9,6 +9,7 @@ from application.llm.premai import PremAILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.novita import NovitaLLM
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
llms = {
|
||||
"openai": OpenAILLM,
|
||||
@@ -21,12 +22,14 @@ class LLMCreator:
|
||||
"premai": PremAILLM,
|
||||
"groq": GroqLLM,
|
||||
"google": GoogleLLM,
|
||||
"novita": NovitaLLM
|
||||
"novita": NovitaLLM,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
|
||||
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
return llm_class(api_key, user_api_key, *args, **kwargs)
|
||||
return llm_class(
|
||||
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
|
||||
)
|
||||
|
||||
@@ -69,6 +69,7 @@ pymongo==4.10.1
|
||||
pypdf==5.2.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
python-jose==3.4.0
|
||||
python-pptx==1.0.2
|
||||
qdrant-client==1.13.2
|
||||
redis==5.2.1
|
||||
|
||||
@@ -17,6 +17,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.source = source
|
||||
@@ -35,6 +36,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
)
|
||||
)
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
@@ -81,7 +83,10 @@ class BraveRetSearch(BaseRetriever):
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
@@ -100,5 +105,5 @@ class BraveRetSearch(BaseRetriever):
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=None,
|
||||
llm_name=settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.original_question = ""
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
@@ -37,10 +38,14 @@ class ClassicRAG(BaseRetriever):
|
||||
self.llm_name = llm_name
|
||||
self.api_key = api_key
|
||||
self.llm = LLMCreator.create_llm(
|
||||
self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
|
||||
self.llm_name,
|
||||
api_key=self.api_key,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.question = self._rephrase_query()
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _rephrase_query(self):
|
||||
if (
|
||||
|
||||
@@ -17,6 +17,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.source = source
|
||||
@@ -35,6 +36,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
)
|
||||
)
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _parse_lang_string(self, input_string):
|
||||
result = []
|
||||
@@ -88,17 +90,20 @@ class DuckDuckSearch(BaseRetriever):
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 0:
|
||||
if len(self.chat_history) > 0:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
@@ -107,7 +112,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def search(self):
|
||||
return self._get_data()
|
||||
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"question": self.question,
|
||||
@@ -117,5 +122,5 @@ class DuckDuckSearch(BaseRetriever):
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
|
||||
@@ -9,10 +9,15 @@ db = mongo["docsgpt"]
|
||||
usage_collection = db["token_usage"]
|
||||
|
||||
|
||||
def update_token_usage(user_api_key, token_usage):
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
if decoded_token:
|
||||
user_id = decoded_token["sub"]
|
||||
else:
|
||||
user_id = None
|
||||
usage_data = {
|
||||
"user_id": user_id,
|
||||
"api_key": user_api_key,
|
||||
"prompt_tokens": token_usage["prompt_tokens"],
|
||||
"generated_tokens": token_usage["generated_tokens"],
|
||||
@@ -35,7 +40,7 @@ def gen_token_usage(func):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
|
||||
result
|
||||
)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -54,6 +59,6 @@ def stream_token_usage(func):
|
||||
yield r
|
||||
for line in batch:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user