diff --git a/application/agents/base.py b/application/agents/base.py
index 7e36c991..d0f972a9 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -9,10 +9,21 @@ from application.llm.llm_creator import LLMCreator
class BaseAgent:
- def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
+ def __init__(
+ self,
+ endpoint,
+ llm_name,
+ gpt_model,
+ api_key,
+ user_api_key=None,
+ decoded_token=None,
+ ):
self.endpoint = endpoint
self.llm = LLMCreator.create_llm(
- llm_name, api_key=api_key, user_api_key=user_api_key
+ llm_name,
+ api_key=api_key,
+ user_api_key=user_api_key,
+ decoded_token=decoded_token,
)
self.llm_handler = get_llm_handler(llm_name)
self.gpt_model = gpt_model
diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py
index 8848c6f6..2752c833 100644
--- a/application/agents/classic_agent.py
+++ b/application/agents/classic_agent.py
@@ -17,8 +17,12 @@ class ClassicAgent(BaseAgent):
user_api_key=None,
prompt="",
chat_history=None,
+ decoded_token=None,
):
- super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
+ super().__init__(
+ endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
+ )
+ self.user = decoded_token.get("sub")
self.prompt = prompt
self.chat_history = chat_history if chat_history is not None else []
@@ -73,7 +77,7 @@ class ClassicAgent(BaseAgent):
)
messages_combine.append({"role": "user", "content": query})
- tools_dict = self._get_user_tools()
+ tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)
resp = self._llm_gen(messages_combine, log_context)
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 7f88ba0f..34081784 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -124,6 +124,7 @@ def save_conversation(
source_log_docs,
tool_calls,
llm,
+ decoded_token,
index=None,
api_key=None,
):
@@ -182,7 +183,7 @@ def save_conversation(
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_data = {
- "user": "local",
+ "user": decoded_token.get("sub"),
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
@@ -223,6 +224,7 @@ def complete_stream(
retriever,
conversation_id,
user_api_key,
+ decoded_token,
isNoneDoc=False,
index=None,
should_save_conversation=True,
@@ -262,7 +264,10 @@ def complete_stream(
doc["source"] = "None"
llm = LLMCreator.create_llm(
- settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+ settings.LLM_NAME,
+ api_key=settings.API_KEY,
+ user_api_key=user_api_key,
+ decoded_token=decoded_token,
)
if should_save_conversation:
@@ -273,6 +278,7 @@ def complete_stream(
source_log_docs,
tool_calls,
llm,
+ decoded_token,
index,
api_key=user_api_key,
)
@@ -288,7 +294,7 @@ def complete_stream(
{
"action": "stream_answer",
"level": "info",
- "user": "local",
+ "user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
@@ -383,15 +389,21 @@ class Stream(Resource):
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
+ decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
+ decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
+ decoded_token = request.decoded_token
+
+ if not decoded_token:
+ return make_response({"error": "Unauthorized"}, 401)
logger.info(
f"/stream - request_data: {data}, source: {source}",
@@ -411,6 +423,7 @@ class Stream(Resource):
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
+ decoded_token=decoded_token,
)
retriever = RetrieverCreator.create_retriever(
@@ -422,6 +435,7 @@ class Stream(Resource):
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
+ decoded_token=decoded_token,
)
return Response(
@@ -431,6 +445,7 @@ class Stream(Resource):
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
+ decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
@@ -523,13 +538,21 @@ class Answer(Resource):
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
+ decoded_token = {"sub": data_key.get("user")}
+
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
+ decoded_token = request.decoded_token
+
else:
source = {}
user_api_key = None
+ decoded_token = request.decoded_token
+
+ if not decoded_token:
+ return make_response({"error": "Unauthorized"}, 401)
prompt = get_prompt(prompt_id)
@@ -547,6 +570,7 @@ class Answer(Resource):
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
+ decoded_token=decoded_token,
)
retriever = RetrieverCreator.create_retriever(
@@ -558,6 +582,7 @@ class Answer(Resource):
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
+ decoded_token=decoded_token,
)
response_full = ""
@@ -571,6 +596,7 @@ class Answer(Resource):
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
+ decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=False,
@@ -604,7 +630,10 @@ class Answer(Resource):
doc["source"] = "None"
llm = LLMCreator.create_llm(
- settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+ settings.LLM_NAME,
+ api_key=settings.API_KEY,
+ user_api_key=user_api_key,
+ decoded_token=decoded_token,
)
result = {"answer": response_full, "sources": source_log_docs}
@@ -616,6 +645,7 @@ class Answer(Resource):
source_log_docs,
tool_calls,
llm,
+ decoded_token,
api_key=user_api_key,
)
)
@@ -625,7 +655,7 @@ class Answer(Resource):
{
"action": "api_answer",
"level": "info",
- "user": "local",
+ "user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
@@ -694,12 +724,20 @@ class Search(Resource):
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
+ decoded_token = {"sub": data_key.get("user")}
+
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
+ decoded_token = request.decoded_token
+
else:
source = {}
user_api_key = None
+ decoded_token = request.decoded_token
+
+ if not decoded_token:
+ return make_response({"error": "Unauthorized"}, 401)
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
@@ -715,6 +753,7 @@ class Search(Resource):
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
+ decoded_token=decoded_token,
)
docs = retriever.search(question)
@@ -724,7 +763,7 @@ class Search(Resource):
{
"action": "api_search",
"level": "info",
- "user": "local",
+ "user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"sources": docs,
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index d7fb4d89..f3599c7e 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -15,7 +15,6 @@ from werkzeug.utils import secure_filename
from application.agents.tools.tool_manager import ToolManager
from application.api.user.tasks import ingest, ingest_remote
-
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
@@ -68,6 +67,21 @@ def generate_date_range(start_date, end_date):
}
+def get_vector_store(source_id):
+ """
+ Get the Vector Store
+ Args:
+ source_id (str): source id of the document
+ """
+
+ store = VectorCreator.create_vectorstore(
+ settings.VECTOR_STORE,
+ source_id=source_id,
+ embeddings_key=os.getenv("EMBEDDINGS_KEY"),
+ )
+ return store
+
+
@user_ns.route("/api/delete_conversation")
class DeleteConversation(Resource):
@api.doc(
@@ -75,6 +89,9 @@ class DeleteConversation(Resource):
params={"id": "The ID of the conversation to delete"},
)
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
@@ -82,7 +99,9 @@ class DeleteConversation(Resource):
)
try:
- conversations_collection.delete_one({"_id": ObjectId(conversation_id)})
+ conversations_collection.delete_one(
+ {"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
+ )
except Exception as err:
current_app.logger.error(f"Error deleting conversation: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -95,7 +114,10 @@ class DeleteAllConversations(Resource):
description="Deletes all conversations for a specific user",
)
def get(self):
- user_id = "local"
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user_id = decoded_token.get("sub")
try:
conversations_collection.delete_many({"user": user_id})
except Exception as err:
@@ -110,11 +132,18 @@ class GetConversations(Resource):
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
)
def get(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
try:
- conversations = conversations_collection.find(
- {"api_key": {"$exists": False}}
- ).sort("date", -1).limit(30)
-
+ conversations = (
+ conversations_collection.find(
+ {"api_key": {"$exists": False}, "user": decoded_token.get("sub")}
+ )
+ .sort("date", -1)
+ .limit(30)
+ )
+
list_conversations = [
{"id": str(conversation["_id"]), "name": conversation["name"]}
for conversation in conversations
@@ -132,6 +161,9 @@ class GetSingleConversation(Resource):
params={"id": "The conversation ID"},
)
def get(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
@@ -140,7 +172,7 @@ class GetSingleConversation(Resource):
try:
conversation = conversations_collection.find_one(
- {"_id": ObjectId(conversation_id)}
+ {"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
@@ -167,6 +199,9 @@ class UpdateConversationName(Resource):
description="Updates the name of a conversation",
)
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["id", "name"]
missing_fields = check_required_fields(data, required_fields)
@@ -175,7 +210,8 @@ class UpdateConversationName(Resource):
try:
conversations_collection.update_one(
- {"_id": ObjectId(data["id"])}, {"$set": {"name": data["name"]}}
+ {"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
+ {"$set": {"name": data["name"]}},
)
except Exception as err:
current_app.logger.error(f"Error updating conversation name: {err}")
@@ -210,6 +246,9 @@ class SubmitFeedback(Resource):
description="Submit feedback for a conversation",
)
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["feedback", "conversation_id", "question_index"]
missing_fields = check_required_fields(data, required_fields)
@@ -222,12 +261,13 @@ class SubmitFeedback(Resource):
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
+ "user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$unset": {
f"queries.{data['question_index']}.feedback": "",
- f"queries.{data['question_index']}.feedback_timestamp": ""
+ f"queries.{data['question_index']}.feedback_timestamp": "",
}
},
)
@@ -236,12 +276,17 @@ class SubmitFeedback(Resource):
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
+ "user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$set": {
- f"queries.{data['question_index']}.feedback": data["feedback"],
- f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(datetime.timezone.utc)
+ f"queries.{data['question_index']}.feedback": data[
+ "feedback"
+ ],
+ f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
+ datetime.timezone.utc
+ ),
}
},
)
@@ -284,13 +329,18 @@ class DeleteOldIndexes(Resource):
params={"source_id": "The source ID to delete"},
)
def get(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
- doc = sources_collection.find_one({"_id": ObjectId(source_id), "user": "local"})
+ doc = sources_collection.find_one(
+ {"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
+ )
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
try:
@@ -328,6 +378,9 @@ class UploadFile(Resource):
description="Uploads a file to be vectorized and indexed",
)
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
data = request.form
files = request.files.getlist("file")
required_fields = ["user", "name"]
@@ -343,7 +396,7 @@ class UploadFile(Resource):
400,
)
- user = secure_filename(request.form["user"])
+ user = secure_filename(decoded_token.get("sub"))
job_name = secure_filename(request.form["name"])
try:
save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
@@ -443,6 +496,9 @@ class UploadRemote(Resource):
description="Uploads remote source for vectorization",
)
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
@@ -463,7 +519,7 @@ class UploadRemote(Resource):
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
- user=data["user"],
+ user=decoded_token.get("sub"),
loader=data["source"],
)
except Exception as err:
@@ -519,7 +575,10 @@ class RedirectToSources(Resource):
class PaginatedSources(Resource):
@api.doc(description="Get document with pagination, sorting and filtering")
def get(self):
- user = "local"
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
sort_field = request.args.get("sort", "date") # Default to 'date'
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
@@ -584,7 +643,10 @@ class PaginatedSources(Resource):
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
def get(self):
- user = "local"
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = [
{
"name": "Default",
@@ -685,13 +747,16 @@ class CreatePrompt(Resource):
@api.expect(create_prompt_model)
@api.doc(description="Create a new prompt")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["content", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
- user = "local"
+ user = decoded_token.get("sub")
try:
resp = prompts_collection.insert_one(
@@ -713,7 +778,10 @@ class CreatePrompt(Resource):
class GetPrompts(Resource):
@api.doc(description="Get all prompts for the user")
def get(self):
- user = "local"
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
try:
prompts = prompts_collection.find({"user": user})
list_prompts = [
@@ -741,6 +809,10 @@ class GetPrompts(Resource):
class GetSinglePrompt(Resource):
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
def get(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
prompt_id = request.args.get("id")
if not prompt_id:
return make_response(
@@ -771,7 +843,9 @@ class GetSinglePrompt(Resource):
chat_reduce_strict = f.read()
return make_response(jsonify({"content": chat_reduce_strict}), 200)
- prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
+ prompt = prompts_collection.find_one(
+ {"_id": ObjectId(prompt_id), "user": user}
+ )
except Exception as err:
current_app.logger.error(f"Error retrieving prompt: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -789,6 +863,10 @@ class DeletePrompt(Resource):
@api.expect(delete_prompt_model)
@api.doc(description="Delete a prompt by ID")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
@@ -796,7 +874,7 @@ class DeletePrompt(Resource):
return missing_fields
try:
- prompts_collection.delete_one({"_id": ObjectId(data["id"])})
+ prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}")
return make_response(jsonify({"success": False}), 400)
@@ -820,6 +898,10 @@ class UpdatePrompt(Resource):
@api.expect(update_prompt_model)
@api.doc(description="Update an existing prompt")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "name", "content"]
missing_fields = check_required_fields(data, required_fields)
@@ -828,7 +910,7 @@ class UpdatePrompt(Resource):
try:
prompts_collection.update_one(
- {"_id": ObjectId(data["id"])},
+ {"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
except Exception as err:
@@ -842,7 +924,10 @@ class UpdatePrompt(Resource):
class GetApiKeys(Resource):
@api.doc(description="Retrieve API keys for the user")
def get(self):
- user = "local"
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
try:
keys = api_key_collection.find({"user": user})
list_keys = []
@@ -889,13 +974,16 @@ class CreateApiKey(Resource):
@api.expect(create_api_key_model)
@api.doc(description="Create a new API key")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["name", "prompt_id", "chunks"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
- user = "local"
try:
key = str(uuid.uuid4())
new_api_key = {
@@ -929,6 +1017,10 @@ class DeleteApiKey(Resource):
@api.expect(delete_api_key_model)
@api.doc(description="Delete an API key by ID")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
@@ -936,7 +1028,9 @@ class DeleteApiKey(Resource):
return missing_fields
try:
- result = api_key_collection.delete_one({"_id": ObjectId(data["id"])})
+ result = api_key_collection.delete_one(
+ {"_id": ObjectId(data["id"]), "user": user}
+ )
if result.deleted_count == 0:
return {"success": False, "message": "API Key not found"}, 404
except Exception as err:
@@ -963,6 +1057,10 @@ class ShareConversation(Resource):
@api.expect(share_conversation_model)
@api.doc(description="Share a conversation")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["conversation_id"]
missing_fields = check_required_fields(data, required_fields)
@@ -974,8 +1072,6 @@ class ShareConversation(Resource):
return make_response(
jsonify({"success": False, "message": "isPromptable is required"}), 400
)
-
- user = data.get("user", "local")
conversation_id = data["conversation_id"]
try:
@@ -1211,7 +1307,13 @@ class GetMessageAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
- enum=["last_hour", "last_24_hour", "last_7_days", "last_15_days", "last_30_days"],
+ enum=[
+ "last_hour",
+ "last_24_hour",
+ "last_7_days",
+ "last_15_days",
+ "last_30_days",
+ ],
),
},
)
@@ -1219,13 +1321,19 @@ class GetMessageAnalytics(Resource):
@api.expect(get_message_analytics_model)
@api.doc(description="Get message analytics based on filter option")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
- api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
+ api_key_collection.find_one(
+ {"_id": ObjectId(api_key_id), "user": user}
+ )["key"]
if api_key_id
else None
)
@@ -1244,9 +1352,9 @@ class GetMessageAnalytics(Resource):
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
- 6 if filter_option == "last_7_days"
- else 14 if filter_option == "last_15_days"
- else 29
+ 6
+ if filter_option == "last_7_days"
+ else 14 if filter_option == "last_15_days" else 29
)
else:
return make_response(
@@ -1254,41 +1362,40 @@ class GetMessageAnalytics(Resource):
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
- end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
+ end_date = end_date.replace(
+ hour=23, minute=59, second=59, microsecond=999999
+ )
group_format = "%Y-%m-%d"
try:
+ match_stage = {
+ "$match": {
+ "user": user,
+ }
+ }
+ if api_key:
+ match_stage["$match"]["api_key"] = api_key
+
pipeline = [
- # Initial match for API key if provided
- {
- "$match": {
- "api_key": api_key if api_key else {"$exists": False}
- }
- },
+ match_stage,
{"$unwind": "$queries"},
- # Match queries within the time range
{
"$match": {
- "queries.timestamp": {
- "$gte": start_date,
- "$lte": end_date
- }
+ "queries.timestamp": {"$gte": start_date, "$lte": end_date}
}
},
- # Group by formatted timestamp
{
"$group": {
"_id": {
"$dateToString": {
"format": group_format,
- "date": "$queries.timestamp"
+ "date": "$queries.timestamp",
}
},
- "count": {"$sum": 1}
+ "count": {"$sum": 1},
}
},
- # Sort by timestamp
- {"$sort": {"_id": 1}}
+ {"$sort": {"_id": 1}},
]
message_data = conversations_collection.aggregate(pipeline)
@@ -1338,13 +1445,19 @@ class GetTokenAnalytics(Resource):
@api.expect(get_token_analytics_model)
@api.doc(description="Get token analytics data")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
- api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
+ api_key_collection.find_one(
+ {"_id": ObjectId(api_key_id), "user": user}
+ )["key"]
if api_key_id
else None
)
@@ -1426,13 +1539,12 @@ class GetTokenAnalytics(Resource):
try:
match_stage = {
"$match": {
+ "user_id": user,
"timestamp": {"$gte": start_date, "$lte": end_date},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
- else:
- match_stage["$match"]["api_key"] = {"$exists": False}
token_usage_data = token_usage_collection.aggregate(
[
@@ -1492,13 +1604,19 @@ class GetFeedbackAnalytics(Resource):
@api.expect(get_feedback_analytics_model)
@api.doc(description="Get feedback analytics data")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
- api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
+ api_key_collection.find_one(
+ {"_id": ObjectId(api_key_id), "user": user}
+ )["key"]
if api_key_id
else None
)
@@ -1511,11 +1629,21 @@ class GetFeedbackAnalytics(Resource):
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
- date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
+ date_field = {
+ "$dateToString": {
+ "format": group_format,
+ "date": "$queries.feedback_timestamp",
+ }
+ }
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
- date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
+ date_field = {
+ "$dateToString": {
+ "format": group_format,
+ "date": "$queries.feedback_timestamp",
+ }
+ }
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
@@ -1533,21 +1661,26 @@ class GetFeedbackAnalytics(Resource):
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
- date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
+ date_field = {
+ "$dateToString": {
+ "format": group_format,
+ "date": "$queries.feedback_timestamp",
+ }
+ }
try:
match_stage = {
"$match": {
- "queries.feedback_timestamp": {"$gte": start_date, "$lte": end_date},
- "queries.feedback": {"$exists": True}
+ "queries.feedback_timestamp": {
+ "$gte": start_date,
+ "$lte": end_date,
+ },
+ "queries.feedback": {"$exists": True},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
- else:
- match_stage["$match"]["api_key"] = {"$exists": False}
- # Unwind the queries array to process each query separately
pipeline = [
match_stage,
{"$unwind": "$queries"},
@@ -1634,6 +1767,10 @@ class GetUserLogs(Resource):
@api.expect(get_user_logs_model)
@api.doc(description="Get user logs with pagination")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
page = int(data.get("page", 1))
api_key_id = data.get("api_key_id")
@@ -1650,7 +1787,7 @@ class GetUserLogs(Resource):
current_app.logger.error(f"Error getting API key: {err}")
return make_response(jsonify({"success": False}), 400)
- query = {}
+ query = {"user": user}
if api_key:
query = {"api_key": api_key}
@@ -1708,6 +1845,10 @@ class ManageSync(Resource):
@api.expect(manage_sync_model)
@api.doc(description="Manage sync frequency for sources")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
@@ -1727,7 +1868,7 @@ class ManageSync(Resource):
sources_collection.update_one(
{
"_id": ObjectId(source_id),
- "user": "local",
+ "user": user,
},
update_data,
)
@@ -1804,7 +1945,10 @@ class GetTools(Resource):
@api.doc(description="Get tools created by a user")
def get(self):
try:
- user = "local"
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
tools = user_tools_collection.find({"user": user})
user_tools = []
for tool in tools:
@@ -1847,6 +1991,10 @@ class CreateTool(Resource):
)
@api.doc(description="Create a new tool")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = [
"name",
@@ -1860,7 +2008,6 @@ class CreateTool(Resource):
if missing_fields:
return missing_fields
- user = "local"
transformed_actions = []
for action in data["actions"]:
action["active"] = True
@@ -1911,6 +2058,10 @@ class UpdateTool(Resource):
)
@api.doc(description="Update a tool by ID")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
@@ -1946,7 +2097,7 @@ class UpdateTool(Resource):
update_data["status"] = data["status"]
user_tools_collection.update_one(
- {"_id": ObjectId(data["id"]), "user": "local"},
+ {"_id": ObjectId(data["id"]), "user": user},
{"$set": update_data},
)
except Exception as err:
@@ -1971,6 +2122,10 @@ class UpdateToolConfig(Resource):
)
@api.doc(description="Update the configuration of a tool")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "config"]
missing_fields = check_required_fields(data, required_fields)
@@ -1979,7 +2134,7 @@ class UpdateToolConfig(Resource):
try:
user_tools_collection.update_one(
- {"_id": ObjectId(data["id"])},
+ {"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": data["config"]}},
)
except Exception as err:
@@ -2006,6 +2161,10 @@ class UpdateToolActions(Resource):
)
@api.doc(description="Update the actions of a tool")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "actions"]
missing_fields = check_required_fields(data, required_fields)
@@ -2014,7 +2173,7 @@ class UpdateToolActions(Resource):
try:
user_tools_collection.update_one(
- {"_id": ObjectId(data["id"])},
+ {"_id": ObjectId(data["id"]), "user": user},
{"$set": {"actions": data["actions"]}},
)
except Exception as err:
@@ -2039,6 +2198,10 @@ class UpdateToolStatus(Resource):
)
@api.doc(description="Update the status of a tool")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "status"]
missing_fields = check_required_fields(data, required_fields)
@@ -2047,7 +2210,7 @@ class UpdateToolStatus(Resource):
try:
user_tools_collection.update_one(
- {"_id": ObjectId(data["id"])},
+ {"_id": ObjectId(data["id"]), "user": user},
{"$set": {"status": data["status"]}},
)
except Exception as err:
@@ -2067,6 +2230,10 @@ class DeleteTool(Resource):
)
@api.doc(description="Delete a tool by ID")
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
@@ -2074,7 +2241,9 @@ class DeleteTool(Resource):
return missing_fields
try:
- result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])})
+ result = user_tools_collection.delete_one(
+ {"_id": ObjectId(data["id"]), "user": user}
+ )
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
@@ -2084,21 +2253,6 @@ class DeleteTool(Resource):
return {"success": True}, 200
-def get_vector_store(source_id):
- """
- Get the Vector Store
- Args:
- source_id (str): source id of the document
- """
-
- store = VectorCreator.create_vectorstore(
- settings.VECTOR_STORE,
- source_id=source_id,
- embeddings_key=os.getenv("EMBEDDINGS_KEY"),
- )
- return store
-
-
@user_ns.route("/api/get_chunks")
class GetChunks(Resource):
@api.doc(
@@ -2106,6 +2260,10 @@ class GetChunks(Resource):
params={"id": "The document ID"},
)
def get(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
doc_id = request.args.get("id")
page = int(request.args.get("page", 1))
per_page = int(request.args.get("per_page", 10))
@@ -2113,6 +2271,12 @@ class GetChunks(Resource):
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
+ doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
+ if not doc:
+ return make_response(
+ jsonify({"error": "Document not found or access denied"}), 404
+ )
+
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
@@ -2157,6 +2321,10 @@ class AddChunk(Resource):
description="Adds a new chunk to the document",
)
def post(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "text"]
missing_fields = check_required_fields(data, required_fields)
@@ -2170,6 +2338,12 @@ class AddChunk(Resource):
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
+ doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
+ if not doc:
+ return make_response(
+ jsonify({"error": "Document not found or access denied"}), 404
+ )
+
try:
store = get_vector_store(doc_id)
chunk_id = store.add_chunk(text, metadata)
@@ -2189,12 +2363,22 @@ class DeleteChunk(Resource):
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
)
def delete(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
doc_id = request.args.get("id")
chunk_id = request.args.get("chunk_id")
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
+ doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
+ if not doc:
+ return make_response(
+ jsonify({"error": "Document not found or access denied"}), 404
+ )
+
try:
store = get_vector_store(doc_id)
deleted = store.delete_chunk(chunk_id)
@@ -2236,6 +2420,10 @@ class UpdateChunk(Resource):
description="Updates an existing chunk in the document.",
)
def put(self):
+ decoded_token = request.decoded_token
+ if not decoded_token:
+ return make_response(jsonify({"success": False}), 401)
+ user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "chunk_id"]
missing_fields = check_required_fields(data, required_fields)
@@ -2250,6 +2438,12 @@ class UpdateChunk(Resource):
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
+ doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
+ if not doc:
+ return make_response(
+ jsonify({"error": "Document not found or access denied"}), 404
+ )
+
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
diff --git a/application/app.py b/application/app.py
index 4eb40331..7ca0ac2b 100644
--- a/application/app.py
+++ b/application/app.py
@@ -1,20 +1,28 @@
+import os
import platform
+import uuid
import dotenv
-from flask import Flask, redirect, request
+from flask import Flask, jsonify, redirect, request
+from jose import jwt
+
+from application.auth import handle_auth
+
from application.core.logging_config import setup_logging
+
setup_logging()
-from application.api.answer.routes import answer # noqa: E402
-from application.api.internal.routes import internal # noqa: E402
-from application.api.user.routes import user # noqa: E402
-from application.celery_init import celery # noqa: E402
-from application.core.settings import settings # noqa: E402
-from application.extensions import api # noqa: E402
+from application.api.answer.routes import answer # noqa: E402
+from application.api.internal.routes import internal # noqa: E402
+from application.api.user.routes import user # noqa: E402
+from application.celery_init import celery # noqa: E402
+from application.core.settings import settings # noqa: E402
+from application.extensions import api # noqa: E402
if platform.system() == "Windows":
import pathlib
+
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
@@ -32,6 +40,25 @@ app.config.update(
celery.config_from_object("application.celeryconfig")
api.init_app(app)
+if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECRET_KEY:
+ key_file = ".jwt_secret_key"
+ try:
+ with open(key_file, "r") as f:
+ settings.JWT_SECRET_KEY = f.read().strip()
+ except FileNotFoundError:
+ new_key = os.urandom(32).hex()
+ with open(key_file, "w") as f:
+ f.write(new_key)
+ settings.JWT_SECRET_KEY = new_key
+ except Exception as e:
+ raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
+
+SIMPLE_JWT_TOKEN = None
+if settings.AUTH_TYPE == "simple_jwt":
+ payload = {"sub": "local"}
+ SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256")
+ print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}")
+
@app.route("/")
def home():
@@ -41,11 +68,47 @@ def home():
return "Welcome to DocsGPT Backend!"
+@app.route("/api/config")
+def get_config():
+ response = {
+ "auth_type": settings.AUTH_TYPE,
+ "requires_auth": settings.AUTH_TYPE in ["simple_jwt", "session_jwt"],
+ }
+ return jsonify(response)
+
+
+@app.route("/api/generate_token")
+def generate_token():
+ if settings.AUTH_TYPE == "session_jwt":
+ new_user_id = str(uuid.uuid4())
+ token = jwt.encode(
+ {"sub": new_user_id}, settings.JWT_SECRET_KEY, algorithm="HS256"
+ )
+ return jsonify({"token": token})
+ return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
+
+
+@app.before_request
+def authenticate_request():
+ if request.method == "OPTIONS":
+ return "", 200
+
+ decoded_token = handle_auth(request)
+ if not decoded_token:
+ request.decoded_token = None
+ elif "error" in decoded_token:
+ return jsonify(decoded_token), 401
+ else:
+ request.decoded_token = decoded_token
+
+
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
- response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
- response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
+ response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
+ response.headers.add(
+ "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
+ )
return response
diff --git a/application/auth.py b/application/auth.py
new file mode 100644
index 00000000..78926c45
--- /dev/null
+++ b/application/auth.py
@@ -0,0 +1,28 @@
+from jose import jwt
+
+from application.core.settings import settings
+
+
+def handle_auth(request, data={}):
+ if settings.AUTH_TYPE in ["simple_jwt", "session_jwt"]:
+ jwt_token = request.headers.get("Authorization")
+ if not jwt_token:
+ return None
+
+ jwt_token = jwt_token.replace("Bearer ", "")
+
+ try:
+ decoded_token = jwt.decode(
+ jwt_token,
+ settings.JWT_SECRET_KEY,
+ algorithms=["HS256"],
+ options={"verify_exp": False},
+ )
+ return decoded_token
+ except Exception as e:
+ return {
+ "message": f"Authentication error: {str(e)}",
+ "error": "invalid_token",
+ }
+ else:
+ return {"sub": "local"}
diff --git a/application/core/settings.py b/application/core/settings.py
index 04d7bbea..74bffe53 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -10,6 +10,7 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
+ AUTH_TYPE: Optional[str] = None
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = (
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
@@ -98,6 +99,8 @@ class Settings(BaseSettings):
FLASK_DEBUG_MODE: bool = False
+ JWT_SECRET_KEY: str = ""
+
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/base.py b/application/llm/base.py
index e687e567..0fce208c 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -5,7 +5,8 @@ from application.usage import gen_token_usage, stream_token_usage
class BaseLLM(ABC):
- def __init__(self):
+ def __init__(self, decoded_token=None):
+ self.decoded_token = decoded_token
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
def _apply_decorator(self, method, decorators, *args, **kwargs):
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 9f1305ba..3ed23854 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -9,6 +9,7 @@ from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.novita import NovitaLLM
+
class LLMCreator:
llms = {
"openai": OpenAILLM,
@@ -21,12 +22,14 @@ class LLMCreator:
"premai": PremAILLM,
"groq": GroqLLM,
"google": GoogleLLM,
- "novita": NovitaLLM
+ "novita": NovitaLLM,
}
@classmethod
- def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
+ def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
- return llm_class(api_key, user_api_key, *args, **kwargs)
+ return llm_class(
+ api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
+ )
diff --git a/application/requirements.txt b/application/requirements.txt
index 713ae2e3..5323fe85 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -69,6 +69,7 @@ pymongo==4.10.1
pypdf==5.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
+python-jose==3.4.0
python-pptx==1.0.2
qdrant-client==1.13.2
redis==5.2.1
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 08b16bc0..ed490734 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -17,6 +17,7 @@ class BraveRetSearch(BaseRetriever):
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
+ decoded_token=None,
):
self.question = question
self.source = source
@@ -35,6 +36,7 @@ class BraveRetSearch(BaseRetriever):
)
)
self.user_api_key = user_api_key
+ self.decoded_token = decoded_token
def _get_data(self):
if self.chunks == 0:
@@ -81,7 +83,10 @@ class BraveRetSearch(BaseRetriever):
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
- settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+ settings.LLM_NAME,
+ api_key=settings.API_KEY,
+ user_api_key=self.user_api_key,
+ decoded_token=self.decoded_token,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -100,5 +105,5 @@ class BraveRetSearch(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
- "user_api_key": self.user_api_key
+ "user_api_key": self.user_api_key,
}
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 03f17f44..08771337 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -17,6 +17,7 @@ class ClassicRAG(BaseRetriever):
user_api_key=None,
llm_name=settings.LLM_NAME,
api_key=settings.API_KEY,
+ decoded_token=None,
):
self.original_question = ""
self.chat_history = chat_history if chat_history is not None else []
@@ -37,10 +38,14 @@ class ClassicRAG(BaseRetriever):
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
- self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
+ self.llm_name,
+ api_key=self.api_key,
+ user_api_key=self.user_api_key,
+ decoded_token=decoded_token,
)
self.question = self._rephrase_query()
self.vectorstore = source["active_docs"] if "active_docs" in source else None
+ self.decoded_token = decoded_token
def _rephrase_query(self):
if (
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index c6386410..9ce73995 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -17,6 +17,7 @@ class DuckDuckSearch(BaseRetriever):
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
+ decoded_token=None,
):
self.question = question
self.source = source
@@ -35,6 +36,7 @@ class DuckDuckSearch(BaseRetriever):
)
)
self.user_api_key = user_api_key
+ self.decoded_token = decoded_token
def _parse_lang_string(self, input_string):
result = []
@@ -88,17 +90,20 @@ class DuckDuckSearch(BaseRetriever):
for doc in docs:
yield {"source": doc}
- if len(self.chat_history) > 0:
+ if len(self.chat_history) > 0:
for i in self.chat_history:
- if "prompt" in i and "response" in i:
- messages_combine.append({"role": "user", "content": i["prompt"]})
- messages_combine.append(
- {"role": "assistant", "content": i["response"]}
- )
+ if "prompt" in i and "response" in i:
+ messages_combine.append({"role": "user", "content": i["prompt"]})
+ messages_combine.append(
+ {"role": "assistant", "content": i["response"]}
+ )
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
- settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+ settings.LLM_NAME,
+ api_key=settings.API_KEY,
+ user_api_key=self.user_api_key,
+ decoded_token=self.decoded_token,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -107,7 +112,7 @@ class DuckDuckSearch(BaseRetriever):
def search(self):
return self._get_data()
-
+
def get_params(self):
return {
"question": self.question,
@@ -117,5 +122,5 @@ class DuckDuckSearch(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
- "user_api_key": self.user_api_key
+ "user_api_key": self.user_api_key,
}
diff --git a/application/usage.py b/application/usage.py
index a18a3848..85328c1f 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -9,10 +9,15 @@ db = mongo["docsgpt"]
usage_collection = db["token_usage"]
-def update_token_usage(user_api_key, token_usage):
+def update_token_usage(decoded_token, user_api_key, token_usage):
if "pytest" in sys.modules:
return
+ if decoded_token:
+ user_id = decoded_token["sub"]
+ else:
+ user_id = None
usage_data = {
+ "user_id": user_id,
"api_key": user_api_key,
"prompt_tokens": token_usage["prompt_tokens"],
"generated_tokens": token_usage["generated_tokens"],
@@ -35,7 +40,7 @@ def gen_token_usage(func):
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
result
)
- update_token_usage(self.user_api_key, self.token_usage)
+ update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
return result
return wrapper
@@ -54,6 +59,6 @@ def stream_token_usage(func):
yield r
for line in batch:
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
- update_token_usage(self.user_api_key, self.token_usage)
+ update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
return wrapper
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index ba0a4bd7..64c4c486 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -1,15 +1,30 @@
-import { Routes, Route } from 'react-router-dom';
-import Navigation from './Navigation';
-import Conversation from './conversation/Conversation';
-import About from './About';
-import PageNotFound from './PageNotFound';
-import { useMediaQuery } from './hooks';
-import { useState } from 'react';
-import Setting from './settings';
import './locale/i18n';
-import { Outlet } from 'react-router-dom';
+
+import { useState } from 'react';
+import { Outlet, Route, Routes } from 'react-router-dom';
+
+import About from './About';
+import Spinner from './components/Spinner';
+import Conversation from './conversation/Conversation';
import { SharedConversation } from './conversation/SharedConversation';
-import { useDarkTheme } from './hooks';
+import { useDarkTheme, useMediaQuery } from './hooks';
+import useTokenAuth from './hooks/useTokenAuth';
+import Navigation from './Navigation';
+import PageNotFound from './PageNotFound';
+import Setting from './settings';
+
+function AuthWrapper({ children }: { children: React.ReactNode }) {
+ const { isAuthLoading } = useTokenAuth();
+
+ if (isAuthLoading) {
+ return (
+
+
+
+ );
+ }
+ return <>{children}>;
+}
function MainLayout() {
const { isMobile } = useMediaQuery();
@@ -39,7 +54,13 @@ export default function App() {
return (
- }>
+
+
+
+ }
+ >
} />
} />
} />
diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx
index 49795b41..9e1889aa 100644
--- a/frontend/src/Navigation.tsx
+++ b/frontend/src/Navigation.tsx
@@ -2,28 +2,35 @@ import { useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import { NavLink, useNavigate } from 'react-router-dom';
+
import conversationService from './api/services/conversationService';
import userService from './api/services/userService';
import Add from './assets/add.svg';
-import openNewChat from './assets/openNewChat.svg';
-import Hamburger from './assets/hamburger.svg';
import DocsGPT3 from './assets/cute_docsgpt3.svg';
import Discord from './assets/discord.svg';
import Expand from './assets/expand.svg';
import Github from './assets/github.svg';
+import Hamburger from './assets/hamburger.svg';
+import openNewChat from './assets/openNewChat.svg';
import SettingGear from './assets/settingGear.svg';
+import SpinnerDark from './assets/spinner-dark.svg';
+import Spinner from './assets/spinner.svg';
import Twitter from './assets/TwitterX.svg';
import UploadIcon from './assets/upload.svg';
+import Help from './components/Help';
import SourceDropdown from './components/SourceDropdown';
import {
+ handleAbort,
+ selectQueries,
setConversation,
updateConversationId,
- handleAbort,
} from './conversation/conversationSlice';
import ConversationTile from './conversation/ConversationTile';
import { useDarkTheme, useMediaQuery } from './hooks';
import useDefaultDocument from './hooks/useDefaultDocument';
+import useTokenAuth from './hooks/useTokenAuth';
import DeleteConvModal from './modals/DeleteConvModal';
+import JWTModal from './modals/JWTModal';
import { ActiveState, Doc } from './models/misc';
import { getConversations, getDocs } from './preferences/preferenceApi';
import {
@@ -31,20 +38,17 @@ import {
selectConversationId,
selectConversations,
selectModalStateDeleteConv,
+ selectPaginatedDocuments,
selectSelectedDocs,
selectSourceDocs,
- selectPaginatedDocuments,
+ selectToken,
setConversations,
setModalStateDeleteConv,
+ setPaginatedDocuments,
setSelectedDocs,
setSourceDocs,
- setPaginatedDocuments,
} from './preferences/preferenceSlice';
-import Spinner from './assets/spinner.svg';
-import SpinnerDark from './assets/spinner-dark.svg';
-import { selectQueries } from './conversation/conversationSlice';
import Upload from './upload/Upload';
-import Help from './components/Help';
interface NavigationProps {
navOpen: boolean;
@@ -53,6 +57,7 @@ interface NavigationProps {
export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const dispatch = useDispatch();
+ const token = useSelector(selectToken);
const queries = useSelector(selectQueries);
const docs = useSelector(selectSourceDocs);
const selectedDocs = useSelector(selectSelectedDocs);
@@ -68,6 +73,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const { t } = useTranslation();
const isApiKeySet = useSelector(selectApiKeyStatus);
+ const { showTokenModal, handleTokenSubmit } = useTokenAuth();
+
const [uploadModalState, setUploadModalState] =
useState('INACTIVE');
@@ -86,7 +93,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
async function fetchConversations() {
dispatch(setConversations({ ...conversations, loading: true }));
- return await getConversations()
+ return await getConversations(token)
.then((fetchedConversations) => {
dispatch(setConversations(fetchedConversations));
})
@@ -99,7 +106,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const handleDeleteAllConversations = () => {
setIsDeletingConversation(true);
conversationService
- .deleteAll()
+ .deleteAll(token)
.then(() => {
fetchConversations();
})
@@ -109,7 +116,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const handleDeleteConversation = (id: string) => {
setIsDeletingConversation(true);
conversationService
- .delete(id, {})
+ .delete(id, {}, token)
.then(() => {
fetchConversations();
resetConversation();
@@ -119,9 +126,9 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const handleDeleteClick = (doc: Doc) => {
userService
- .deletePath(doc.id ?? '')
+ .deletePath(doc.id ?? '', token)
.then(() => {
- return getDocs();
+ return getDocs(token);
})
.then((updatedDocs) => {
dispatch(setSourceDocs(updatedDocs));
@@ -145,7 +152,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const handleConversationClick = (index: string) => {
conversationService
- .getConversation(index)
+ .getConversation(index, token)
.then((response) => response.json())
.then((data) => {
navigate('/');
@@ -177,7 +184,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
id: string;
}) {
await conversationService
- .update(updatedConversation)
+ .update(updatedConversation, token)
.then((response) => response.json())
.then((data) => {
if (data) {
@@ -197,8 +204,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
useEffect(() => {
setNavOpen(!isMobile);
}, [isMobile]);
- useDefaultDocument();
+ useDefaultDocument();
return (
<>
{!navOpen && (
@@ -472,6 +479,10 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
close={() => setUploadModalState('INACTIVE')}
>
)}
+
>
);
}
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 21699721..3db613fc 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -4,14 +4,24 @@ const defaultHeaders = {
'Content-Type': 'application/json',
};
+const getHeaders = (token: string | null, customHeaders = {}): HeadersInit => {
+ return {
+ ...defaultHeaders,
+ ...(token ? { Authorization: `Bearer ${token}` } : {}),
+ ...customHeaders,
+ };
+};
+
const apiClient = {
- get: (url: string, headers = {}, signal?: AbortSignal): Promise =>
+ get: (
+ url: string,
+ token: string | null,
+ headers = {},
+ signal?: AbortSignal,
+ ): Promise =>
fetch(`${baseURL}${url}`, {
method: 'GET',
- headers: {
- ...defaultHeaders,
- ...headers,
- },
+ headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
@@ -20,15 +30,13 @@ const apiClient = {
post: (
url: string,
data: any,
+ token: string | null,
headers = {},
signal?: AbortSignal,
): Promise =>
fetch(`${baseURL}${url}`, {
method: 'POST',
- headers: {
- ...defaultHeaders,
- ...headers,
- },
+ headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
@@ -38,28 +46,28 @@ const apiClient = {
put: (
url: string,
data: any,
+ token: string | null,
headers = {},
signal?: AbortSignal,
): Promise =>
fetch(`${baseURL}${url}`, {
method: 'PUT',
- headers: {
- ...defaultHeaders,
- ...headers,
- },
+ headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
- delete: (url: string, headers = {}, signal?: AbortSignal): Promise =>
+ delete: (
+ url: string,
+ token: string | null,
+ headers = {},
+ signal?: AbortSignal,
+ ): Promise =>
fetch(`${baseURL}${url}`, {
method: 'DELETE',
- headers: {
- ...defaultHeaders,
- ...headers,
- },
+ headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts
index 9bf659de..0d574f89 100644
--- a/frontend/src/api/endpoints.ts
+++ b/frontend/src/api/endpoints.ts
@@ -1,5 +1,7 @@
const endpoints = {
USER: {
+ CONFIG: '/api/config',
+ NEW_TOKEN: '/api/generate_token',
DOCS: '/api/sources',
DOCS_CHECK: '/api/docs_check',
DOCS_PAGINATED: '/api/sources/paginated',
diff --git a/frontend/src/api/services/conversationService.ts b/frontend/src/api/services/conversationService.ts
index aaf703de..853a6863 100644
--- a/frontend/src/api/services/conversationService.ts
+++ b/frontend/src/api/services/conversationService.ts
@@ -2,31 +2,58 @@ import apiClient from '../client';
import endpoints from '../endpoints';
const conversationService = {
- answer: (data: any, signal: AbortSignal): Promise =>
- apiClient.post(endpoints.CONVERSATION.ANSWER, data, {}, signal),
- answerStream: (data: any, signal: AbortSignal): Promise =>
- apiClient.post(endpoints.CONVERSATION.ANSWER_STREAMING, data, {}, signal),
- search: (data: any): Promise =>
- apiClient.post(endpoints.CONVERSATION.SEARCH, data),
- feedback: (data: any): Promise =>
- apiClient.post(endpoints.CONVERSATION.FEEDBACK, data),
- getConversation: (id: string): Promise =>
- apiClient.get(endpoints.CONVERSATION.CONVERSATION(id)),
- getConversations: (): Promise =>
- apiClient.get(endpoints.CONVERSATION.CONVERSATIONS),
- shareConversation: (isPromptable: boolean, data: any): Promise =>
+ answer: (
+ data: any,
+ token: string | null,
+ signal: AbortSignal,
+ ): Promise =>
+ apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, {}, signal),
+ answerStream: (
+ data: any,
+ token: string | null,
+ signal: AbortSignal,
+ ): Promise =>
+ apiClient.post(
+ endpoints.CONVERSATION.ANSWER_STREAMING,
+ data,
+ token,
+ {},
+ signal,
+ ),
+ search: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.CONVERSATION.SEARCH, data, token, {}),
+ feedback: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.CONVERSATION.FEEDBACK, data, token, {}),
+ getConversation: (id: string, token: string | null): Promise =>
+ apiClient.get(endpoints.CONVERSATION.CONVERSATION(id), token, {}),
+ getConversations: (token: string | null): Promise =>
+ apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}),
+ shareConversation: (
+ isPromptable: boolean,
+ data: any,
+ token: string | null,
+ ): Promise =>
apiClient.post(
endpoints.CONVERSATION.SHARE_CONVERSATION(isPromptable),
data,
+ token,
+ {},
),
- getSharedConversation: (identifier: string): Promise =>
- apiClient.get(endpoints.CONVERSATION.SHARED_CONVERSATION(identifier)),
- delete: (id: string, data: any): Promise =>
- apiClient.post(endpoints.CONVERSATION.DELETE(id), data),
- deleteAll: (): Promise =>
- apiClient.get(endpoints.CONVERSATION.DELETE_ALL),
- update: (data: any): Promise =>
- apiClient.post(endpoints.CONVERSATION.UPDATE, data),
+ getSharedConversation: (
+ identifier: string,
+ token: string | null,
+ ): Promise =>
+ apiClient.get(
+ endpoints.CONVERSATION.SHARED_CONVERSATION(identifier),
+ token,
+ {},
+ ),
+ delete: (id: string, data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.CONVERSATION.DELETE(id), data, token, {}),
+ deleteAll: (token: string | null): Promise =>
+ apiClient.get(endpoints.CONVERSATION.DELETE_ALL, token, {}),
+ update: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.CONVERSATION.UPDATE, data, token, {}),
};
export default conversationService;
diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts
index e7f367f1..13083677 100644
--- a/frontend/src/api/services/userService.ts
+++ b/frontend/src/api/services/userService.ts
@@ -2,63 +2,74 @@ import apiClient from '../client';
import endpoints from '../endpoints';
const userService = {
- getDocs: (): Promise => apiClient.get(`${endpoints.USER.DOCS}`),
- getDocsWithPagination: (query: string): Promise =>
- apiClient.get(`${endpoints.USER.DOCS_PAGINATED}?${query}`),
- checkDocs: (data: any): Promise =>
- apiClient.post(endpoints.USER.DOCS_CHECK, data),
- getAPIKeys: (): Promise => apiClient.get(endpoints.USER.API_KEYS),
- createAPIKey: (data: any): Promise =>
- apiClient.post(endpoints.USER.CREATE_API_KEY, data),
- deleteAPIKey: (data: any): Promise =>
- apiClient.post(endpoints.USER.DELETE_API_KEY, data),
- getPrompts: (): Promise => apiClient.get(endpoints.USER.PROMPTS),
- createPrompt: (data: any): Promise =>
- apiClient.post(endpoints.USER.CREATE_PROMPT, data),
- deletePrompt: (data: any): Promise =>
- apiClient.post(endpoints.USER.DELETE_PROMPT, data),
- updatePrompt: (data: any): Promise =>
- apiClient.post(endpoints.USER.UPDATE_PROMPT, data),
- getSinglePrompt: (id: string): Promise =>
- apiClient.get(endpoints.USER.SINGLE_PROMPT(id)),
- deletePath: (docPath: string): Promise =>
- apiClient.get(endpoints.USER.DELETE_PATH(docPath)),
- getTaskStatus: (task_id: string): Promise =>
- apiClient.get(endpoints.USER.TASK_STATUS(task_id)),
- getMessageAnalytics: (data: any): Promise =>
- apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data),
- getTokenAnalytics: (data: any): Promise =>
- apiClient.post(endpoints.USER.TOKEN_ANALYTICS, data),
- getFeedbackAnalytics: (data: any): Promise =>
- apiClient.post(endpoints.USER.FEEDBACK_ANALYTICS, data),
- getLogs: (data: any): Promise =>
- apiClient.post(endpoints.USER.LOGS, data),
- manageSync: (data: any): Promise =>
- apiClient.post(endpoints.USER.MANAGE_SYNC, data),
- getAvailableTools: (): Promise =>
- apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS),
- getUserTools: (): Promise =>
- apiClient.get(endpoints.USER.GET_USER_TOOLS),
- createTool: (data: any): Promise =>
- apiClient.post(endpoints.USER.CREATE_TOOL, data),
- updateToolStatus: (data: any): Promise =>
- apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data),
- updateTool: (data: any): Promise =>
- apiClient.post(endpoints.USER.UPDATE_TOOL, data),
- deleteTool: (data: any): Promise =>
- apiClient.post(endpoints.USER.DELETE_TOOL, data),
+ getConfig: (): Promise => apiClient.get(endpoints.USER.CONFIG, null),
+ getNewToken: (): Promise =>
+ apiClient.get(endpoints.USER.NEW_TOKEN, null),
+ getDocs: (token: string | null): Promise =>
+ apiClient.get(`${endpoints.USER.DOCS}`, token),
+ getDocsWithPagination: (query: string, token: string | null): Promise =>
+ apiClient.get(`${endpoints.USER.DOCS_PAGINATED}?${query}`, token),
+ checkDocs: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.DOCS_CHECK, data, token),
+ getAPIKeys: (token: string | null): Promise =>
+ apiClient.get(endpoints.USER.API_KEYS, token),
+ createAPIKey: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.CREATE_API_KEY, data, token),
+ deleteAPIKey: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.DELETE_API_KEY, data, token),
+ getPrompts: (token: string | null): Promise =>
+ apiClient.get(endpoints.USER.PROMPTS, token),
+ createPrompt: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.CREATE_PROMPT, data, token),
+ deletePrompt: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.DELETE_PROMPT, data, token),
+ updatePrompt: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.UPDATE_PROMPT, data, token),
+ getSinglePrompt: (id: string, token: string | null): Promise =>
+ apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
+ deletePath: (docPath: string, token: string | null): Promise =>
+ apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
+ getTaskStatus: (task_id: string, token: string | null): Promise =>
+ apiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
+ getMessageAnalytics: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
+ getTokenAnalytics: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.TOKEN_ANALYTICS, data, token),
+ getFeedbackAnalytics: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.FEEDBACK_ANALYTICS, data, token),
+ getLogs: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.LOGS, data, token),
+ manageSync: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.MANAGE_SYNC, data, token),
+ getAvailableTools: (token: string | null): Promise =>
+ apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token),
+ getUserTools: (token: string | null): Promise =>
+ apiClient.get(endpoints.USER.GET_USER_TOOLS, token),
+ createTool: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.CREATE_TOOL, data, token),
+ updateToolStatus: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data, token),
+ updateTool: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.UPDATE_TOOL, data, token),
+ deleteTool: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.DELETE_TOOL, data, token),
getDocumentChunks: (
docId: string,
page: number,
perPage: number,
+ token: string | null,
): Promise =>
- apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage)),
- addChunk: (data: any): Promise =>
- apiClient.post(endpoints.USER.ADD_CHUNK, data),
- deleteChunk: (docId: string, chunkId: string): Promise =>
- apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId)),
- updateChunk: (data: any): Promise =>
- apiClient.put(endpoints.USER.UPDATE_CHUNK, data),
+ apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage), token),
+ addChunk: (data: any, token: string | null): Promise =>
+ apiClient.post(endpoints.USER.ADD_CHUNK, data, token),
+ deleteChunk: (
+ docId: string,
+ chunkId: string,
+ token: string | null,
+ ): Promise =>
+ apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId), token),
+ updateChunk: (data: any, token: string | null): Promise =>
+ apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token),
};
export default userService;
diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx
index 9f54ddbd..2dd9e773 100644
--- a/frontend/src/conversation/Conversation.tsx
+++ b/frontend/src/conversation/Conversation.tsx
@@ -7,7 +7,10 @@ import newChatIcon from '../assets/openNewChat.svg';
import ShareIcon from '../assets/share.svg';
import { useMediaQuery } from '../hooks';
import { ShareConversationModal } from '../modals/ShareConversationModal';
-import { selectConversationId } from '../preferences/preferenceSlice';
+import {
+ selectConversationId,
+ selectToken,
+} from '../preferences/preferenceSlice';
import { AppDispatch } from '../store';
import { handleSendFeedback } from './conversationHandlers';
import { FEEDBACK, Query } from './conversationModels';
@@ -27,6 +30,7 @@ import ConversationMessages from './ConversationMessages';
import MessageInput from '../components/MessageInput';
export default function Conversation() {
+ const token = useSelector(selectToken);
const queries = useSelector(selectQueries);
const status = useSelector(selectStatus);
const conversationId = useSelector(selectConversationId);
@@ -118,6 +122,7 @@ export default function Conversation() {
feedback,
conversationId as string,
index,
+ token,
).catch(() =>
handleSendFeedback(
query.prompt,
@@ -125,6 +130,7 @@ export default function Conversation() {
feedback,
conversationId as string,
index,
+ token,
).catch(() =>
dispatch(updateQuery({ index, query: { feedback: prevFeedback } })),
),
diff --git a/frontend/src/conversation/SharedConversation.tsx b/frontend/src/conversation/SharedConversation.tsx
index 993556e4..be822805 100644
--- a/frontend/src/conversation/SharedConversation.tsx
+++ b/frontend/src/conversation/SharedConversation.tsx
@@ -1,35 +1,34 @@
import { useEffect, useState } from 'react';
+import { Helmet } from 'react-helmet';
import { useTranslation } from 'react-i18next';
+import { useDispatch, useSelector } from 'react-redux';
import { useNavigate, useParams } from 'react-router-dom';
-import ConversationMessages from './ConversationMessages';
-import MessageInput from '../components/MessageInput';
+
import conversationService from '../api/services/conversationService';
+import MessageInput from '../components/MessageInput';
+import { selectToken } from '../preferences/preferenceSlice';
+import { AppDispatch } from '../store';
+import { formatDate } from '../utils/dateTimeUtils';
+import ConversationMessages from './ConversationMessages';
import {
- selectClientAPIKey,
- setClientApiKey,
- updateQuery,
addQuery,
fetchSharedAnswer,
- selectStatus,
-} from './sharedConversationSlice';
-import { setIdentifier, setFetchedData } from './sharedConversationSlice';
-
-import { useDispatch } from 'react-redux';
-import { AppDispatch } from '../store';
-
-import {
+ selectClientAPIKey,
selectDate,
- selectTitle,
selectQueries,
+ selectStatus,
+ selectTitle,
+ setClientApiKey,
+ setFetchedData,
+ setIdentifier,
+ updateQuery,
} from './sharedConversationSlice';
-import { useSelector } from 'react-redux';
-import { Helmet } from 'react-helmet';
-import { formatDate } from '../utils/dateTimeUtils';
export const SharedConversation = () => {
const navigate = useNavigate();
const { identifier } = useParams(); //identifier is a uuid, not conversationId
+ const token = useSelector(selectToken);
const queries = useSelector(selectQueries);
const title = useSelector(selectTitle);
const date = useSelector(selectDate);
@@ -56,7 +55,7 @@ export const SharedConversation = () => {
const fetchQueries = () => {
identifier &&
conversationService
- .getSharedConversation(identifier || '')
+ .getSharedConversation(identifier || '', token)
.then((res) => {
if (res.status === 404 || res.status === 400)
navigate('/pagenotfound');
diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts
index 0b54a366..88771fc5 100644
--- a/frontend/src/conversation/conversationHandlers.ts
+++ b/frontend/src/conversation/conversationHandlers.ts
@@ -6,6 +6,7 @@ import { ToolCallsType } from './types';
export function handleFetchAnswer(
question: string,
signal: AbortSignal,
+ token: string | null,
selectedDocs: Doc | null,
history: Array = [],
conversationId: string | null,
@@ -52,7 +53,7 @@ export function handleFetchAnswer(
}
payload.retriever = selectedDocs?.retriever as string;
return conversationService
- .answer(payload, signal)
+ .answer(payload, token, signal)
.then((response) => {
if (response.ok) {
return response.json();
@@ -76,6 +77,7 @@ export function handleFetchAnswer(
export function handleFetchAnswerSteaming(
question: string,
signal: AbortSignal,
+ token: string | null,
selectedDocs: Doc | null,
history: Array = [],
conversationId: string | null,
@@ -109,7 +111,7 @@ export function handleFetchAnswerSteaming(
return new Promise((resolve, reject) => {
conversationService
- .answerStream(payload, signal)
+ .answerStream(payload, token, signal)
.then((response) => {
if (!response.body) throw Error('No response body');
@@ -160,6 +162,7 @@ export function handleFetchAnswerSteaming(
export function handleSearch(
question: string,
+ token: string | null,
selectedDocs: Doc | null,
conversation_id: string | null,
history: Array = [],
@@ -185,7 +188,7 @@ export function handleSearch(
payload.active_docs = selectedDocs.id as string;
payload.retriever = selectedDocs?.retriever as string;
return conversationService
- .search(payload)
+ .search(payload, token)
.then((response) => response.json())
.then((data) => {
return data;
@@ -206,11 +209,14 @@ export function handleSearchViaApiKey(
};
});
return conversationService
- .search({
- question: question,
- history: JSON.stringify(history),
- api_key: api_key,
- })
+ .search(
+ {
+ question: question,
+ history: JSON.stringify(history),
+ api_key: api_key,
+ },
+ null,
+ )
.then((response) => response.json())
.then((data) => {
return data;
@@ -224,15 +230,19 @@ export function handleSendFeedback(
feedback: FEEDBACK,
conversation_id: string,
prompt_index: number,
+ token: string | null,
) {
return conversationService
- .feedback({
- question: prompt,
- answer: response,
- feedback: feedback,
- conversation_id: conversation_id,
- question_index: prompt_index,
- })
+ .feedback(
+ {
+ question: prompt,
+ answer: response,
+ feedback: feedback,
+ conversation_id: conversation_id,
+ question_index: prompt_index,
+ },
+ token,
+ )
.then((response) => {
if (response.ok) {
return Promise.resolve();
@@ -265,7 +275,7 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations
save_conversation: false,
};
conversationService
- .answerStream(payload, signal)
+ .answerStream(payload, null, signal)
.then((response) => {
if (!response.body) throw Error('No response body');
@@ -339,6 +349,7 @@ export function handleFetchSharedAnswer(
question: question,
api_key: apiKey,
},
+ null,
signal,
)
.then((response) => {
diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts
index f00eb546..7cd14d5e 100644
--- a/frontend/src/conversation/conversationSlice.ts
+++ b/frontend/src/conversation/conversationSlice.ts
@@ -42,6 +42,7 @@ export const fetchAnswer = createAsyncThunk<
await handleFetchAnswerSteaming(
question,
signal,
+ state.preference.token,
state.preference.selectedDocs!,
state.conversation.queries,
state.conversation.conversationId,
@@ -53,7 +54,7 @@ export const fetchAnswer = createAsyncThunk<
if (data.type === 'end') {
dispatch(conversationSlice.actions.setStatus('idle'));
- getConversations()
+ getConversations(state.preference.token)
.then((fetchedConversations) => {
dispatch(setConversations(fetchedConversations));
})
@@ -114,6 +115,7 @@ export const fetchAnswer = createAsyncThunk<
const answer = await handleFetchAnswer(
question,
signal,
+ state.preference.token,
state.preference.selectedDocs!,
state.conversation.queries,
state.conversation.conversationId,
@@ -150,7 +152,7 @@ export const fetchAnswer = createAsyncThunk<
}),
);
dispatch(conversationSlice.actions.setStatus('idle'));
- getConversations()
+ getConversations(state.preference.token)
.then((fetchedConversations) => {
dispatch(setConversations(fetchedConversations));
})
diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts
index 7f4b9812..a2642dc5 100644
--- a/frontend/src/hooks/useDefaultDocument.ts
+++ b/frontend/src/hooks/useDefaultDocument.ts
@@ -1,20 +1,22 @@
import React from 'react';
import { useDispatch, useSelector } from 'react-redux';
-import { getDocs } from '../preferences/preferenceApi';
import { Doc } from '../models/misc';
+import { getDocs } from '../preferences/preferenceApi';
import {
selectSelectedDocs,
+ selectToken,
setSelectedDocs,
setSourceDocs,
} from '../preferences/preferenceSlice';
export default function useDefaultDocument() {
const dispatch = useDispatch();
+ const token = useSelector(selectToken);
const selectedDoc = useSelector(selectSelectedDocs);
const fetchDocs = () => {
- getDocs().then((data) => {
+ getDocs(token).then((data) => {
dispatch(setSourceDocs(data));
if (!selectedDoc)
Array.isArray(data) &&
diff --git a/frontend/src/hooks/useTokenAuth.ts b/frontend/src/hooks/useTokenAuth.ts
new file mode 100644
index 00000000..8f408600
--- /dev/null
+++ b/frontend/src/hooks/useTokenAuth.ts
@@ -0,0 +1,55 @@
+import { useEffect, useRef, useState } from 'react';
+import { useDispatch, useSelector } from 'react-redux';
+
+import userService from '../api/services/userService';
+import { selectToken, setToken } from '../preferences/preferenceSlice';
+
+export default function useAuth() {
+ const dispatch = useDispatch();
+ const token = useSelector(selectToken);
+ const [authType, setAuthType] = useState(null);
+ const [showTokenModal, setShowTokenModal] = useState(false);
+ const [isAuthLoading, setIsAuthLoading] = useState(true);
+ const isGeneratingToken = useRef(false);
+
+ const generateNewToken = async () => {
+ if (isGeneratingToken.current) return;
+ isGeneratingToken.current = true;
+ const response = await userService.getNewToken();
+ const { token: newToken } = await response.json();
+ localStorage.setItem('authToken', newToken);
+ dispatch(setToken(newToken));
+ setIsAuthLoading(false);
+ return newToken;
+ };
+
+ useEffect(() => {
+ const initializeAuth = async () => {
+ try {
+ const configRes = await userService.getConfig();
+ const config = await configRes.json();
+ setAuthType(config.auth_type);
+
+ if (config.auth_type === 'session_jwt' && !token) {
+ await generateNewToken();
+ } else if (config.auth_type === 'simple_jwt' && !token) {
+ setShowTokenModal(true);
+ setIsAuthLoading(false);
+ } else {
+ setIsAuthLoading(false);
+ }
+ } catch (error) {
+ console.error('Auth initialization failed:', error);
+ setIsAuthLoading(false);
+ }
+ };
+ initializeAuth();
+ }, []);
+
+ const handleTokenSubmit = (enteredToken: string) => {
+ localStorage.setItem('authToken', enteredToken);
+ dispatch(setToken(enteredToken));
+ setShowTokenModal(false);
+ };
+ return { authType, showTokenModal, isAuthLoading, token, handleTokenSubmit };
+}
diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx
index 42b55d69..9885edab 100644
--- a/frontend/src/modals/AddToolModal.tsx
+++ b/frontend/src/modals/AddToolModal.tsx
@@ -1,12 +1,14 @@
import React, { useRef } from 'react';
import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
+import Spinner from '../components/Spinner';
import { useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc';
+import { selectToken } from '../preferences/preferenceSlice';
import ConfigToolModal from './ConfigToolModal';
import { AvailableToolType } from './types';
-import Spinner from '../components/Spinner';
import WrapperComponent from './WrapperModal';
export default function AddToolModal({
@@ -23,6 +25,7 @@ export default function AddToolModal({
onToolAdded: (toolId: string) => void;
}) {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const modalRef = useRef(null);
const [availableTools, setAvailableTools] = React.useState<
AvailableToolType[]
@@ -42,7 +45,7 @@ export default function AddToolModal({
const getAvailableTools = () => {
setLoading(true);
userService
- .getAvailableTools()
+ .getAvailableTools(token)
.then((res) => {
return res.json();
})
@@ -55,14 +58,17 @@ export default function AddToolModal({
const handleAddTool = (tool: AvailableToolType) => {
if (Object.keys(tool.configRequirements).length === 0) {
userService
- .createTool({
- name: tool.name,
- displayName: tool.displayName,
- description: tool.description,
- config: {},
- actions: tool.actions,
- status: true,
- })
+ .createTool(
+ {
+ name: tool.name,
+ displayName: tool.displayName,
+ description: tool.description,
+ config: {},
+ actions: tool.actions,
+ status: true,
+ },
+ token,
+ )
.then((res) => {
if (res.status === 200) {
return res.json();
diff --git a/frontend/src/modals/ConfigToolModal.tsx b/frontend/src/modals/ConfigToolModal.tsx
index 05517c51..e631419a 100644
--- a/frontend/src/modals/ConfigToolModal.tsx
+++ b/frontend/src/modals/ConfigToolModal.tsx
@@ -1,11 +1,13 @@
import React from 'react';
import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
-import WrapperModal from './WrapperModal';
+import userService from '../api/services/userService';
import Input from '../components/Input';
import { ActiveState } from '../models/misc';
+import { selectToken } from '../preferences/preferenceSlice';
import { AvailableToolType } from './types';
-import userService from '../api/services/userService';
+import WrapperModal from './WrapperModal';
interface ConfigToolModalProps {
modalState: ActiveState;
@@ -21,18 +23,22 @@ export default function ConfigToolModal({
getUserTools,
}: ConfigToolModalProps) {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const [authKey, setAuthKey] = React.useState('');
const handleAddTool = (tool: AvailableToolType) => {
userService
- .createTool({
- name: tool.name,
- displayName: tool.displayName,
- description: tool.description,
- config: { token: authKey },
- actions: tool.actions,
- status: true,
- })
+ .createTool(
+ {
+ name: tool.name,
+ displayName: tool.displayName,
+ description: tool.description,
+ config: { token: authKey },
+ actions: tool.actions,
+ status: true,
+ },
+ token,
+ )
.then(() => {
setModalState('INACTIVE');
getUserTools();
diff --git a/frontend/src/modals/CreateAPIKeyModal.tsx b/frontend/src/modals/CreateAPIKeyModal.tsx
index a35efd60..79e2120d 100644
--- a/frontend/src/modals/CreateAPIKeyModal.tsx
+++ b/frontend/src/modals/CreateAPIKeyModal.tsx
@@ -6,7 +6,7 @@ import userService from '../api/services/userService';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
import { CreateAPIKeyModalProps, Doc } from '../models/misc';
-import { selectSourceDocs } from '../preferences/preferenceSlice';
+import { selectSourceDocs, selectToken } from '../preferences/preferenceSlice';
import WrapperModal from './WrapperModal';
const embeddingsName =
@@ -18,6 +18,7 @@ export default function CreateAPIKeyModal({
createAPIKey,
}: CreateAPIKeyModalProps) {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const docs = useSelector(selectSourceDocs);
const [APIKeyName, setAPIKeyName] = React.useState('');
@@ -60,7 +61,7 @@ export default function CreateAPIKeyModal({
React.useEffect(() => {
const handleFetchPrompts = async () => {
try {
- const response = await userService.getPrompts();
+ const response = await userService.getPrompts(token);
if (!response.ok) {
throw new Error('Failed to fetch prompts');
}
diff --git a/frontend/src/modals/JWTModal.tsx b/frontend/src/modals/JWTModal.tsx
new file mode 100644
index 00000000..5f25b217
--- /dev/null
+++ b/frontend/src/modals/JWTModal.tsx
@@ -0,0 +1,47 @@
+import React, { useState } from 'react';
+import { useDispatch } from 'react-redux';
+
+import Input from '../components/Input';
+import { ActiveState } from '../models/misc';
+import WrapperModal from './WrapperModal';
+
+type JWTModalProps = {
+ modalState: ActiveState;
+ handleTokenSubmit: (enteredToken: string) => void;
+};
+
+export default function JWTModal({
+ modalState,
+ handleTokenSubmit,
+}: JWTModalProps) {
+ const [jwtToken, setJwtToken] = useState('');
+
+ if (modalState !== 'ACTIVE') return null;
+
+ return (
+ {}}>
+
+
+ Add JWT Token
+
+
+
+ setJwtToken(e.target.value)}
+ borderVariant="thin"
+ />
+
+
+
+ );
+}
diff --git a/frontend/src/modals/ShareConversationModal.tsx b/frontend/src/modals/ShareConversationModal.tsx
index 3fd444ad..73bb5acd 100644
--- a/frontend/src/modals/ShareConversationModal.tsx
+++ b/frontend/src/modals/ShareConversationModal.tsx
@@ -1,16 +1,21 @@
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
-import {
- selectSourceDocs,
- selectSelectedDocs,
- selectChunks,
- selectPrompt,
-} from '../preferences/preferenceSlice';
+
+import conversationService from '../api/services/conversationService';
+import Spinner from '../assets/spinner.svg';
import Dropdown from '../components/Dropdown';
import ToggleSwitch from '../components/ToggleSwitch';
import { Doc } from '../models/misc';
-import Spinner from '../assets/spinner.svg';
+import {
+ selectChunks,
+ selectPrompt,
+ selectSelectedDocs,
+ selectSourceDocs,
+ selectToken,
+} from '../preferences/preferenceSlice';
+import WrapperModal from './WrapperModal';
+
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
const embeddingsName =
import.meta.env.VITE_EMBEDDINGS_NAME ||
@@ -18,9 +23,6 @@ const embeddingsName =
type StatusType = 'loading' | 'idle' | 'fetched' | 'failed';
-import conversationService from '../api/services/conversationService';
-import WrapperModal from './WrapperModal';
-
export const ShareConversationModal = ({
close,
conversationId,
@@ -29,6 +31,7 @@ export const ShareConversationModal = ({
conversationId: string;
}) => {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const domain = window.location.origin;
@@ -86,7 +89,7 @@ export const ShareConversationModal = ({
sourcePath && (payload.source = sourcePath.value);
}
conversationService
- .shareConversation(isPromptable, payload)
+ .shareConversation(isPromptable, payload, token)
.then((res) => {
return res.json();
})
diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts
index 8d21bdcd..d52580e0 100644
--- a/frontend/src/preferences/preferenceApi.ts
+++ b/frontend/src/preferences/preferenceApi.ts
@@ -3,9 +3,9 @@ import userService from '../api/services/userService';
import { Doc, GetDocsResponse } from '../models/misc';
//Fetches all JSON objects from the source. We only use the objects with the "model" property in SelectDocsModal.tsx. Hopefully can clean up the source file later.
-export async function getDocs(): Promise {
+export async function getDocs(token: string | null): Promise {
try {
- const response = await userService.getDocs();
+ const response = await userService.getDocs(token);
const data = await response.json();
const docs: Doc[] = [];
@@ -26,10 +26,11 @@ export async function getDocsWithPagination(
pageNumber = 1,
rowsPerPage = 10,
searchTerm = '',
+ token: string | null,
): Promise {
try {
const query = `sort=${sort}&order=${order}&page=${pageNumber}&rows=${rowsPerPage}&search=${searchTerm}`;
- const response = await userService.getDocsWithPagination(query);
+ const response = await userService.getDocsWithPagination(query, token);
const data = await response.json();
const docs: Doc[] = [];
Array.isArray(data.paginated) &&
@@ -48,12 +49,12 @@ export async function getDocsWithPagination(
}
}
-export async function getConversations(): Promise<{
+export async function getConversations(token: string | null): Promise<{
data: { name: string; id: string }[] | null;
loading: boolean;
}> {
try {
- const response = await conversationService.getConversations();
+ const response = await conversationService.getConversations(token);
const data = await response.json();
const conversations: { name: string; id: string }[] = [];
@@ -100,8 +101,11 @@ export function setLocalRecentDocs(doc: Doc | null): void {
docPath = 'local' + '/' + doc.name + '/';
}
userService
- .checkDocs({
- docs: docPath,
- })
+ .checkDocs(
+ {
+ docs: docPath,
+ },
+ null,
+ )
.then((response) => response.json());
}
diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts
index 8b3064d5..4bca1a37 100644
--- a/frontend/src/preferences/preferenceSlice.ts
+++ b/frontend/src/preferences/preferenceSlice.ts
@@ -19,6 +19,7 @@ export interface Preference {
data: { name: string; id: string }[] | null;
loading: boolean;
};
+ token: string | null;
modalState: ActiveState;
paginatedDocuments: Doc[] | null;
}
@@ -42,6 +43,7 @@ const initialState: Preference = {
data: null,
loading: false,
},
+ token: localStorage.getItem('authToken') || null,
modalState: 'INACTIVE',
paginatedDocuments: null,
};
@@ -65,6 +67,9 @@ export const prefSlice = createSlice({
setConversations: (state, action) => {
state.conversations = action.payload;
},
+ setToken: (state, action) => {
+ state.token = action.payload;
+ },
setPrompt: (state, action) => {
state.prompt = action.payload;
},
@@ -85,6 +90,7 @@ export const {
setSelectedDocs,
setSourceDocs,
setConversations,
+ setToken,
setPrompt,
setChunks,
setTokenLimit,
@@ -157,6 +163,7 @@ export const selectConversations = (state: RootState) =>
state.preference.conversations;
export const selectConversationId = (state: RootState) =>
state.conversation.conversationId;
+export const selectToken = (state: RootState) => state.preference.token;
export const selectPrompt = (state: RootState) => state.preference.prompt;
export const selectChunks = (state: RootState) => state.preference.chunks;
export const selectTokenLimit = (state: RootState) =>
diff --git a/frontend/src/settings/APIKeys.tsx b/frontend/src/settings/APIKeys.tsx
index b892787e..2da36c76 100644
--- a/frontend/src/settings/APIKeys.tsx
+++ b/frontend/src/settings/APIKeys.tsx
@@ -1,17 +1,20 @@
import React, { useState } from 'react';
import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Trash from '../assets/trash.svg';
-import CreateAPIKeyModal from '../modals/CreateAPIKeyModal';
-import SaveAPIKeyModal from '../modals/SaveAPIKeyModal';
-import ConfirmationModal from '../modals/ConfirmationModal';
-import { APIKeyData } from './types';
import SkeletonLoader from '../components/SkeletonLoader';
import { useLoaderState } from '../hooks';
+import ConfirmationModal from '../modals/ConfirmationModal';
+import CreateAPIKeyModal from '../modals/CreateAPIKeyModal';
+import SaveAPIKeyModal from '../modals/SaveAPIKeyModal';
+import { selectToken } from '../preferences/preferenceSlice';
+import { APIKeyData } from './types';
export default function APIKeys() {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const [isCreateModalOpen, setCreateModal] = useState(false);
const [isSaveKeyModalOpen, setSaveKeyModal] = useState(false);
const [newKey, setNewKey] = useState('');
@@ -25,7 +28,7 @@ export default function APIKeys() {
const handleFetchKeys = async () => {
setLoading(true);
try {
- const response = await userService.getAPIKeys();
+ const response = await userService.getAPIKeys(token);
if (!response.ok) {
throw new Error('Failed to fetch API Keys');
}
@@ -41,7 +44,7 @@ export default function APIKeys() {
const handleDeleteKey = (id: string) => {
setLoading(true);
userService
- .deleteAPIKey({ id })
+ .deleteAPIKey({ id }, token)
.then((response) => {
if (!response.ok) {
throw new Error('Failed to delete API Key');
@@ -71,7 +74,7 @@ export default function APIKeys() {
}) => {
setLoading(true);
userService
- .createAPIKey(payload)
+ .createAPIKey(payload, token)
.then((response) => {
if (!response.ok) {
throw new Error('Failed to create API Key');
diff --git a/frontend/src/settings/Analytics.tsx b/frontend/src/settings/Analytics.tsx
index 5ab95bac..a75d9aaf 100644
--- a/frontend/src/settings/Analytics.tsx
+++ b/frontend/src/settings/Analytics.tsx
@@ -1,5 +1,3 @@
-import React, { useState, useEffect } from 'react';
-import { useTranslation } from 'react-i18next';
import {
BarElement,
CategoryScale,
@@ -9,18 +7,21 @@ import {
Title,
Tooltip,
} from 'chart.js';
+import React, { useEffect, useState } from 'react';
import { Bar } from 'react-chartjs-2';
+import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Dropdown from '../components/Dropdown';
+import SkeletonLoader from '../components/SkeletonLoader';
+import { useLoaderState } from '../hooks';
+import { selectToken } from '../preferences/preferenceSlice';
import { htmlLegendPlugin } from '../utils/chartUtils';
import { formatDate } from '../utils/dateTimeUtils';
import { APIKeyData } from './types';
-import { useLoaderState } from '../hooks';
import type { ChartData } from 'chart.js';
-import SkeletonLoader from '../components/SkeletonLoader';
-
ChartJS.register(
CategoryScale,
LinearScale,
@@ -32,6 +33,7 @@ ChartJS.register(
export default function Analytics() {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const filterOptions = [
{ label: t('settings.analytics.filterOptions.hour'), value: 'last_hour' },
@@ -97,7 +99,7 @@ export default function Analytics() {
const fetchChatbots = async () => {
setLoadingChatbots(true);
try {
- const response = await userService.getAPIKeys();
+ const response = await userService.getAPIKeys(token);
if (!response.ok) {
throw new Error('Failed to fetch Chatbots');
}
@@ -113,10 +115,13 @@ export default function Analytics() {
const fetchMessagesData = async (chatbot_id?: string, filter?: string) => {
setLoadingMessages(true);
try {
- const response = await userService.getMessageAnalytics({
- api_key_id: chatbot_id,
- filter_option: filter,
- });
+ const response = await userService.getMessageAnalytics(
+ {
+ api_key_id: chatbot_id,
+ filter_option: filter,
+ },
+ token,
+ );
if (!response.ok) {
throw new Error('Failed to fetch analytics data');
}
@@ -132,10 +137,13 @@ export default function Analytics() {
const fetchTokenData = async (chatbot_id?: string, filter?: string) => {
setLoadingTokens(true);
try {
- const response = await userService.getTokenAnalytics({
- api_key_id: chatbot_id,
- filter_option: filter,
- });
+ const response = await userService.getTokenAnalytics(
+ {
+ api_key_id: chatbot_id,
+ filter_option: filter,
+ },
+ token,
+ );
if (!response.ok) {
throw new Error('Failed to fetch analytics data');
}
@@ -151,10 +159,13 @@ export default function Analytics() {
const fetchFeedbackData = async (chatbot_id?: string, filter?: string) => {
setLoadingFeedback(true);
try {
- const response = await userService.getFeedbackAnalytics({
- api_key_id: chatbot_id,
- filter_option: filter,
- });
+ const response = await userService.getFeedbackAnalytics(
+ {
+ api_key_id: chatbot_id,
+ filter_option: filter,
+ },
+ token,
+ );
if (!response.ok) {
throw new Error('Failed to fetch analytics data');
}
diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx
index 2b29ef08..0fc03f24 100644
--- a/frontend/src/settings/Documents.tsx
+++ b/frontend/src/settings/Documents.tsx
@@ -1,6 +1,6 @@
import React, { useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
-import { useDispatch } from 'react-redux';
+import { useDispatch, useSelector } from 'react-redux';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
@@ -22,6 +22,7 @@ import ConfirmationModal from '../modals/ConfirmationModal';
import { ActiveState, Doc, DocumentsProps } from '../models/misc';
import { getDocs, getDocsWithPagination } from '../preferences/preferenceApi';
import {
+ selectToken,
setPaginatedDocuments,
setSourceDocs,
} from '../preferences/preferenceSlice';
@@ -53,6 +54,7 @@ export default function Documents({
}: DocumentsProps) {
const { t } = useTranslation();
const dispatch = useDispatch();
+ const token = useSelector(selectToken);
const [searchTerm, setSearchTerm] = useState('');
const [modalState, setModalState] = useState('INACTIVE');
@@ -163,6 +165,7 @@ export default function Documents({
page,
rowsPerPg,
searchTerm,
+ token,
)
.then((data) => {
dispatch(setPaginatedDocuments(data ? data.docs : []));
@@ -179,9 +182,9 @@ export default function Documents({
const handleManageSync = (doc: Doc, sync_frequency: string) => {
setLoading(true);
userService
- .manageSync({ source_id: doc.id, sync_frequency })
+ .manageSync({ source_id: doc.id, sync_frequency }, token)
.then(() => {
- return getDocs();
+ return getDocs(token);
})
.then((data) => {
dispatch(setSourceDocs(data));
@@ -190,6 +193,8 @@ export default function Documents({
sortOrder,
currentPage,
rowsPerPage,
+ searchTerm,
+ token,
);
})
.then((paginatedData) => {
@@ -519,6 +524,7 @@ function DocumentChunks({
handleGoBack: () => void;
}) {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const [isDarkTheme] = useDarkTheme();
const [paginatedChunks, setPaginatedChunks] = useState([]);
const [page, setPage] = useState(1);
@@ -536,7 +542,7 @@ function DocumentChunks({
setLoading(true);
try {
userService
- .getDocumentChunks(document.id ?? '', page, perPage)
+ .getDocumentChunks(document.id ?? '', page, perPage, token)
.then((response) => {
if (!response.ok) {
setLoading(false);
@@ -561,13 +567,16 @@ function DocumentChunks({
const handleAddChunk = (title: string, text: string) => {
try {
userService
- .addChunk({
- id: document.id ?? '',
- text: text,
- metadata: {
- title: title,
+ .addChunk(
+ {
+ id: document.id ?? '',
+ text: text,
+ metadata: {
+ title: title,
+ },
},
- })
+ token,
+ )
.then((response) => {
if (!response.ok) {
throw new Error('Failed to add chunk');
@@ -582,14 +591,17 @@ function DocumentChunks({
const handleUpdateChunk = (title: string, text: string, chunk: ChunkType) => {
try {
userService
- .updateChunk({
- id: document.id ?? '',
- chunk_id: chunk.doc_id,
- text: text,
- metadata: {
- title: title,
+ .updateChunk(
+ {
+ id: document.id ?? '',
+ chunk_id: chunk.doc_id,
+ text: text,
+ metadata: {
+ title: title,
+ },
},
- })
+ token,
+ )
.then((response) => {
if (!response.ok) {
throw new Error('Failed to update chunk');
@@ -604,7 +616,7 @@ function DocumentChunks({
const handleDeleteChunk = (chunk: ChunkType) => {
try {
userService
- .deleteChunk(document.id ?? '', chunk.doc_id)
+ .deleteChunk(document.id ?? '', chunk.doc_id, token)
.then((response) => {
if (!response.ok) {
throw new Error('Failed to delete chunk');
diff --git a/frontend/src/settings/General.tsx b/frontend/src/settings/General.tsx
index 210f6bbc..fa64507e 100644
--- a/frontend/src/settings/General.tsx
+++ b/frontend/src/settings/General.tsx
@@ -8,6 +8,7 @@ import { useDarkTheme } from '../hooks';
import {
selectChunks,
selectPrompt,
+ selectToken,
selectTokenLimit,
setChunks,
setModalStateDeleteConv,
@@ -21,6 +22,7 @@ export default function General() {
t,
i18n: { changeLanguage },
} = useTranslation();
+ const token = useSelector(selectToken);
const themes = [
{ value: 'Light', label: t('settings.general.light') },
{ value: 'Dark', label: t('settings.general.dark') },
@@ -64,7 +66,7 @@ export default function General() {
React.useEffect(() => {
const handleFetchPrompts = async () => {
try {
- const response = await userService.getPrompts();
+ const response = await userService.getPrompts(token);
if (!response.ok) {
throw new Error('Failed to fetch prompts');
}
diff --git a/frontend/src/settings/Logs.tsx b/frontend/src/settings/Logs.tsx
index 24cf3a6d..2507c106 100644
--- a/frontend/src/settings/Logs.tsx
+++ b/frontend/src/settings/Logs.tsx
@@ -1,5 +1,6 @@
-import React, { useState, useEffect, useRef, useCallback } from 'react';
+import React, { useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import ChevronRight from '../assets/chevron-right.svg';
@@ -7,10 +8,12 @@ import CopyButton from '../components/CopyButton';
import Dropdown from '../components/Dropdown';
import SkeletonLoader from '../components/SkeletonLoader';
import { useLoaderState } from '../hooks';
+import { selectToken } from '../preferences/preferenceSlice';
import { APIKeyData, LogData } from './types';
export default function Logs() {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
const [chatbots, setChatbots] = useState([]);
const [selectedChatbot, setSelectedChatbot] = useState();
const [logs, setLogs] = useState([]);
@@ -22,7 +25,7 @@ export default function Logs() {
const fetchChatbots = async () => {
setLoadingChatbots(true);
try {
- const response = await userService.getAPIKeys();
+ const response = await userService.getAPIKeys(token);
if (!response.ok) {
throw new Error('Failed to fetch Chatbots');
}
@@ -38,11 +41,14 @@ export default function Logs() {
const fetchLogs = async () => {
setLoadingLogs(true);
try {
- const response = await userService.getLogs({
- page: page,
- api_key_id: selectedChatbot?.id,
- page_size: 10,
- });
+ const response = await userService.getLogs(
+ {
+ page: page,
+ api_key_id: selectedChatbot?.id,
+ page_size: 10,
+ },
+ token,
+ );
if (!response.ok) {
throw new Error('Failed to fetch logs');
}
diff --git a/frontend/src/settings/Prompts.tsx b/frontend/src/settings/Prompts.tsx
index 654b610a..33540296 100644
--- a/frontend/src/settings/Prompts.tsx
+++ b/frontend/src/settings/Prompts.tsx
@@ -1,9 +1,11 @@
import React from 'react';
import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Dropdown from '../components/Dropdown';
import { ActiveState, PromptProps } from '../models/misc';
+import { selectToken } from '../preferences/preferenceSlice';
import PromptsModal from '../preferences/PromptsModal';
export default function Prompts({
@@ -24,6 +26,7 @@ export default function Prompts({
setEditPromptName(name);
onSelectPrompt(name, id, type);
};
+ const token = useSelector(selectToken);
const [newPromptName, setNewPromptName] = React.useState('');
const [newPromptContent, setNewPromptContent] = React.useState('');
const [editPromptName, setEditPromptName] = React.useState('');
@@ -39,10 +42,13 @@ export default function Prompts({
const handleAddPrompt = async () => {
try {
- const response = await userService.createPrompt({
- name: newPromptName,
- content: newPromptContent,
- });
+ const response = await userService.createPrompt(
+ {
+ name: newPromptName,
+ content: newPromptContent,
+ },
+ token,
+ );
if (!response.ok) {
throw new Error('Failed to add prompt');
}
@@ -65,7 +71,7 @@ export default function Prompts({
const handleDeletePrompt = (id: string) => {
setPrompts(prompts.filter((prompt) => prompt.id !== id));
userService
- .deletePrompt({ id })
+ .deletePrompt({ id }, token)
.then((response) => {
if (!response.ok) {
throw new Error('Failed to delete prompt');
@@ -81,7 +87,7 @@ export default function Prompts({
const handleFetchPromptContent = async (id: string) => {
try {
- const response = await userService.getSinglePrompt(id);
+ const response = await userService.getSinglePrompt(id, token);
if (!response.ok) {
throw new Error('Failed to fetch prompt content');
}
@@ -94,11 +100,14 @@ export default function Prompts({
const handleSaveChanges = (id: string, type: string) => {
userService
- .updatePrompt({
- id: id,
- name: editPromptName,
- content: editPromptContent,
- })
+ .updatePrompt(
+ {
+ id: id,
+ name: editPromptName,
+ content: editPromptContent,
+ },
+ token,
+ )
.then((response) => {
if (!response.ok) {
throw new Error('Failed to update prompt');
diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx
index af57db21..d75a3852 100644
--- a/frontend/src/settings/ToolConfig.tsx
+++ b/frontend/src/settings/ToolConfig.tsx
@@ -1,4 +1,6 @@
import React from 'react';
+import { useSelector } from 'react-redux';
+
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import CircleCheck from '../assets/circle-check.svg';
@@ -9,6 +11,7 @@ import Input from '../components/Input';
import ToggleSwitch from '../components/ToggleSwitch';
import AddActionModal from '../modals/AddActionModal';
import { ActiveState } from '../models/misc';
+import { selectToken } from '../preferences/preferenceSlice';
import { APIActionType, APIToolType, UserToolType } from './types';
import { useTranslation } from 'react-i18next';
@@ -21,6 +24,7 @@ export default function ToolConfig({
setTool: (tool: UserToolType | APIToolType) => void;
handleGoBack: () => void;
}) {
+ const token = useSelector(selectToken);
const [authKey, setAuthKey] = React.useState(
'token' in tool.config ? tool.config.token : '',
);
@@ -57,22 +61,25 @@ export default function ToolConfig({
const handleSaveChanges = () => {
userService
- .updateTool({
- id: tool.id,
- name: tool.name,
- displayName: tool.displayName,
- description: tool.description,
- config: tool.name === 'api_tool' ? tool.config : { token: authKey },
- actions: 'actions' in tool ? tool.actions : [],
- status: tool.status,
- })
+ .updateTool(
+ {
+ id: tool.id,
+ name: tool.name,
+ displayName: tool.displayName,
+ description: tool.description,
+ config: tool.name === 'api_tool' ? tool.config : { token: authKey },
+ actions: 'actions' in tool ? tool.actions : [],
+ status: tool.status,
+ },
+ token,
+ )
.then(() => {
handleGoBack();
});
};
const handleDelete = () => {
- userService.deleteTool({ id: tool.id }).then(() => {
+ userService.deleteTool({ id: tool.id }, token).then(() => {
handleGoBack();
});
};
diff --git a/frontend/src/settings/Tools.tsx b/frontend/src/settings/Tools.tsx
index 7432ecf0..b42195ed 100644
--- a/frontend/src/settings/Tools.tsx
+++ b/frontend/src/settings/Tools.tsx
@@ -1,18 +1,22 @@
import React from 'react';
import { useTranslation } from 'react-i18next';
+import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import CogwheelIcon from '../assets/cogwheel.svg';
import Input from '../components/Input';
import Spinner from '../components/Spinner';
+import ToggleSwitch from '../components/ToggleSwitch';
import AddToolModal from '../modals/AddToolModal';
import { ActiveState } from '../models/misc';
+import { selectToken } from '../preferences/preferenceSlice';
import ToolConfig from './ToolConfig';
import { APIToolType, UserToolType } from './types';
-import ToggleSwitch from '../components/ToggleSwitch';
export default function Tools() {
const { t } = useTranslation();
+ const token = useSelector(selectToken);
+
const [searchTerm, setSearchTerm] = React.useState('');
const [addToolModalState, setAddToolModalState] =
React.useState('INACTIVE');
@@ -25,7 +29,7 @@ export default function Tools() {
const getUserTools = () => {
setLoading(true);
userService
- .getUserTools()
+ .getUserTools(token)
.then((res) => {
return res.json();
})
@@ -41,7 +45,7 @@ export default function Tools() {
const updateToolStatus = (toolId: string, newStatus: boolean) => {
userService
- .updateToolStatus({ id: toolId, status: newStatus })
+ .updateToolStatus({ id: toolId, status: newStatus }, token)
.then(() => {
setUserTools((prevTools) =>
prevTools.map((tool) =>
@@ -65,7 +69,7 @@ export default function Tools() {
const handleToolAdded = (toolId: string) => {
userService
- .getUserTools()
+ .getUserTools(token)
.then((res) => res.json())
.then((data) => {
const newTool = data.tools.find(
diff --git a/frontend/src/settings/index.tsx b/frontend/src/settings/index.tsx
index 918e4d15..cd504858 100644
--- a/frontend/src/settings/index.tsx
+++ b/frontend/src/settings/index.tsx
@@ -11,6 +11,7 @@ import {
selectSourceDocs,
setPaginatedDocuments,
setSourceDocs,
+ selectToken,
} from '../preferences/preferenceSlice';
import Analytics from './Analytics';
import APIKeys from './APIKeys';
@@ -28,6 +29,7 @@ export default function Settings() {
null,
);
+ const token = useSelector(selectToken);
const documents = useSelector(selectSourceDocs);
const paginatedDocuments = useSelector(selectPaginatedDocuments);
const updateWidgetScreenshot = (screenshot: File | null) => {
@@ -41,7 +43,7 @@ export default function Settings() {
const handleDeleteClick = (index: number, doc: Doc) => {
userService
- .deletePath(doc.id ?? '')
+ .deletePath(doc.id ?? '', token)
.then((response) => {
if (response.ok && documents) {
if (paginatedDocuments) {
diff --git a/frontend/src/store.ts b/frontend/src/store.ts
index 8f426ed6..02aa9a68 100644
--- a/frontend/src/store.ts
+++ b/frontend/src/store.ts
@@ -16,6 +16,7 @@ const doc = localStorage.getItem('DocsGPTRecentDocs');
const preloadedState: { preference: Preference } = {
preference: {
apiKey: key ?? '',
+ token: localStorage.getItem('authToken') ?? null,
prompt:
prompt !== null
? JSON.parse(prompt)
diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx
index df06af2f..e70c930f 100644
--- a/frontend/src/upload/Upload.tsx
+++ b/frontend/src/upload/Upload.tsx
@@ -9,21 +9,22 @@ import WebsiteCollect from '../assets/website_collect.svg';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
import ToggleSwitch from '../components/ToggleSwitch';
+import WrapperModal from '../modals/WrapperModal';
import { ActiveState, Doc } from '../models/misc';
import { getDocs } from '../preferences/preferenceApi';
import {
+ selectSourceDocs,
+ selectToken,
setSelectedDocs,
setSourceDocs,
- selectSourceDocs,
} from '../preferences/preferenceSlice';
-import WrapperModal from '../modals/WrapperModal';
+import { IngestorDefaultConfigs } from '../upload/types/ingestor';
import {
- IngestorType,
+ FormField,
IngestorConfig,
IngestorFormSchemas,
- FormField,
+ IngestorType,
} from './types/ingestor';
-import { IngestorDefaultConfigs } from '../upload/types/ingestor';
function Upload({
receivedFile = [],
@@ -40,6 +41,7 @@ function Upload({
close: () => void;
onSuccessfulUpload?: () => void;
}) {
+ const token = useSelector(selectToken);
const [docName, setDocName] = useState(receivedFile[0]?.name);
const [remoteName, setRemoteName] = useState('');
const [files, setfiles] = useState(receivedFile);
@@ -297,12 +299,12 @@ function Upload({
if ((progress?.percentage ?? 0) < 100) {
timeoutID = setTimeout(() => {
userService
- .getTaskStatus(progress?.taskId as string)
+ .getTaskStatus(progress?.taskId as string, null)
.then((data) => data.json())
.then((data) => {
if (data.status == 'SUCCESS') {
if (data.result.limited === true) {
- getDocs().then((data) => {
+ getDocs(token).then((data) => {
dispatch(setSourceDocs(data));
dispatch(
setSelectedDocs(
@@ -322,7 +324,7 @@ function Upload({
},
);
} else {
- getDocs().then((data) => {
+ getDocs(token).then((data) => {
dispatch(setSourceDocs(data));
const docIds = new Set(
(Array.isArray(sourceDocs) &&
@@ -413,6 +415,7 @@ function Upload({
}, 3000);
};
xhr.open('POST', `${apiHost + '/api/upload'}`);
+ xhr.setRequestHeader('Authorization', `Bearer ${token}`);
xhr.send(formData);
};