diff --git a/application/agents/base.py b/application/agents/base.py index 7e36c991..d0f972a9 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -9,10 +9,21 @@ from application.llm.llm_creator import LLMCreator class BaseAgent: - def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None): + def __init__( + self, + endpoint, + llm_name, + gpt_model, + api_key, + user_api_key=None, + decoded_token=None, + ): self.endpoint = endpoint self.llm = LLMCreator.create_llm( - llm_name, api_key=api_key, user_api_key=user_api_key + llm_name, + api_key=api_key, + user_api_key=user_api_key, + decoded_token=decoded_token, ) self.llm_handler = get_llm_handler(llm_name) self.gpt_model = gpt_model diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 8848c6f6..2752c833 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -17,8 +17,12 @@ class ClassicAgent(BaseAgent): user_api_key=None, prompt="", chat_history=None, + decoded_token=None, ): - super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key) + super().__init__( + endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token + ) + self.user = decoded_token.get("sub") self.prompt = prompt self.chat_history = chat_history if chat_history is not None else [] @@ -73,7 +77,7 @@ class ClassicAgent(BaseAgent): ) messages_combine.append({"role": "user", "content": query}) - tools_dict = self._get_user_tools() + tools_dict = self._get_user_tools(self.user) self._prepare_tools(tools_dict) resp = self._llm_gen(messages_combine, log_context) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 7f88ba0f..34081784 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -124,6 +124,7 @@ def save_conversation( source_log_docs, tool_calls, llm, + decoded_token, index=None, api_key=None, ): @@ -182,7 +183,7 @@ def save_conversation( completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_data = { - "user": "local", + "user": decoded_token.get("sub"), "date": datetime.datetime.utcnow(), "name": completion, "queries": [ @@ -223,6 +224,7 @@ def complete_stream( retriever, conversation_id, user_api_key, + decoded_token, isNoneDoc=False, index=None, should_save_conversation=True, @@ -262,7 +264,10 @@ def complete_stream( doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key + settings.LLM_NAME, + api_key=settings.API_KEY, + user_api_key=user_api_key, + decoded_token=decoded_token, ) if should_save_conversation: @@ -273,6 +278,7 @@ def complete_stream( source_log_docs, tool_calls, llm, + decoded_token, index, api_key=user_api_key, ) @@ -288,7 +294,7 @@ def complete_stream( { "action": "stream_answer", "level": "info", - "user": "local", + "user": decoded_token.get("sub"), "api_key": user_api_key, "question": question, "response": response_full, @@ -383,15 +389,21 @@ class Stream(Resource): source = {"active_docs": data_key.get("source")} retriever_name = data_key.get("retriever", retriever_name) user_api_key = data["api_key"] + decoded_token = {"sub": data_key.get("user")} elif "active_docs" in data: source = {"active_docs": data["active_docs"]} retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None + decoded_token = request.decoded_token else: source = {} user_api_key = None + decoded_token = request.decoded_token + + if not decoded_token: + return make_response({"error": "Unauthorized"}, 401) logger.info( f"/stream - request_data: {data}, source: {source}", @@ -411,6 +423,7 @@ class Stream(Resource): user_api_key=user_api_key, prompt=prompt, chat_history=history, + decoded_token=decoded_token, ) retriever = RetrieverCreator.create_retriever( @@ -422,6 +435,7 @@ class Stream(Resource): token_limit=token_limit, gpt_model=gpt_model, user_api_key=user_api_key, + decoded_token=decoded_token, ) return Response( @@ -431,6 +445,7 @@ class Stream(Resource): retriever=retriever, conversation_id=conversation_id, user_api_key=user_api_key, + decoded_token=decoded_token, isNoneDoc=data.get("isNoneDoc"), index=index, should_save_conversation=save_conv, @@ -523,13 +538,21 @@ class Answer(Resource): source = {"active_docs": data_key.get("source")} retriever_name = data_key.get("retriever", retriever_name) user_api_key = data["api_key"] + decoded_token = {"sub": data_key.get("user")} + elif "active_docs" in data: source = {"active_docs": data["active_docs"]} retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None + decoded_token = request.decoded_token + else: source = {} user_api_key = None + decoded_token = request.decoded_token + + if not decoded_token: + return make_response({"error": "Unauthorized"}, 401) prompt = get_prompt(prompt_id) @@ -547,6 +570,7 @@ class Answer(Resource): user_api_key=user_api_key, prompt=prompt, chat_history=history, + decoded_token=decoded_token, ) retriever = RetrieverCreator.create_retriever( @@ -558,6 +582,7 @@ class Answer(Resource): token_limit=token_limit, gpt_model=gpt_model, user_api_key=user_api_key, + decoded_token=decoded_token, ) response_full = "" @@ -571,6 +596,7 @@ class Answer(Resource): retriever=retriever, conversation_id=conversation_id, user_api_key=user_api_key, + decoded_token=decoded_token, isNoneDoc=data.get("isNoneDoc"), index=None, should_save_conversation=False, @@ -604,7 +630,10 @@ class Answer(Resource): doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key + settings.LLM_NAME, + api_key=settings.API_KEY, + user_api_key=user_api_key, + decoded_token=decoded_token, ) result = {"answer": response_full, "sources": source_log_docs} @@ -616,6 +645,7 @@ class Answer(Resource): source_log_docs, tool_calls, llm, + decoded_token, api_key=user_api_key, ) ) @@ -625,7 +655,7 @@ class Answer(Resource): { "action": "api_answer", "level": "info", - "user": "local", + "user": decoded_token.get("sub"), "api_key": user_api_key, "question": question, "response": response_full, @@ -694,12 +724,20 @@ class Search(Resource): chunks = int(data_key.get("chunks", 2)) source = {"active_docs": data_key.get("source")} user_api_key = data["api_key"] + decoded_token = {"sub": data_key.get("user")} + elif "active_docs" in data: source = {"active_docs": data["active_docs"]} user_api_key = None + decoded_token = request.decoded_token + else: source = {} user_api_key = None + decoded_token = request.decoded_token + + if not decoded_token: + return make_response({"error": "Unauthorized"}, 401) logger.info( f"/api/answer - request_data: {data}, source: {source}", @@ -715,6 +753,7 @@ class Search(Resource): token_limit=token_limit, gpt_model=gpt_model, user_api_key=user_api_key, + decoded_token=decoded_token, ) docs = retriever.search(question) @@ -724,7 +763,7 @@ class Search(Resource): { "action": "api_search", "level": "info", - "user": "local", + "user": decoded_token.get("sub"), "api_key": user_api_key, "question": question, "sources": docs, diff --git a/application/api/user/routes.py b/application/api/user/routes.py index d7fb4d89..f3599c7e 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -15,7 +15,6 @@ from werkzeug.utils import secure_filename from application.agents.tools.tool_manager import ToolManager from application.api.user.tasks import ingest, ingest_remote - from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api @@ -68,6 +67,21 @@ def generate_date_range(start_date, end_date): } +def get_vector_store(source_id): + """ + Get the Vector Store + Args: + source_id (str): source id of the document + """ + + store = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, + source_id=source_id, + embeddings_key=os.getenv("EMBEDDINGS_KEY"), + ) + return store + + @user_ns.route("/api/delete_conversation") class DeleteConversation(Resource): @api.doc( @@ -75,6 +89,9 @@ class DeleteConversation(Resource): params={"id": "The ID of the conversation to delete"}, ) def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) conversation_id = request.args.get("id") if not conversation_id: return make_response( @@ -82,7 +99,9 @@ class DeleteConversation(Resource): ) try: - conversations_collection.delete_one({"_id": ObjectId(conversation_id)}) + conversations_collection.delete_one( + {"_id": ObjectId(conversation_id), "user": decoded_token["sub"]} + ) except Exception as err: current_app.logger.error(f"Error deleting conversation: {err}") return make_response(jsonify({"success": False}), 400) @@ -95,7 +114,10 @@ class DeleteAllConversations(Resource): description="Deletes all conversations for a specific user", ) def get(self): - user_id = "local" + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user_id = decoded_token.get("sub") try: conversations_collection.delete_many({"user": user_id}) except Exception as err: @@ -110,11 +132,18 @@ class GetConversations(Resource): description="Retrieve a list of the latest 30 conversations (excluding API key conversations)", ) def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) try: - conversations = conversations_collection.find( - {"api_key": {"$exists": False}} - ).sort("date", -1).limit(30) - + conversations = ( + conversations_collection.find( + {"api_key": {"$exists": False}, "user": decoded_token.get("sub")} + ) + .sort("date", -1) + .limit(30) + ) + list_conversations = [ {"id": str(conversation["_id"]), "name": conversation["name"]} for conversation in conversations @@ -132,6 +161,9 @@ class GetSingleConversation(Resource): params={"id": "The conversation ID"}, ) def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) conversation_id = request.args.get("id") if not conversation_id: return make_response( @@ -140,7 +172,7 @@ class GetSingleConversation(Resource): try: conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} + {"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")} ) if not conversation: return make_response(jsonify({"status": "not found"}), 404) @@ -167,6 +199,9 @@ class UpdateConversationName(Resource): description="Updates the name of a conversation", ) def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) data = request.get_json() required_fields = ["id", "name"] missing_fields = check_required_fields(data, required_fields) @@ -175,7 +210,8 @@ class UpdateConversationName(Resource): try: conversations_collection.update_one( - {"_id": ObjectId(data["id"])}, {"$set": {"name": data["name"]}} + {"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")}, + {"$set": {"name": data["name"]}}, ) except Exception as err: current_app.logger.error(f"Error updating conversation name: {err}") @@ -210,6 +246,9 @@ class SubmitFeedback(Resource): description="Submit feedback for a conversation", ) def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) data = request.get_json() required_fields = ["feedback", "conversation_id", "question_index"] missing_fields = check_required_fields(data, required_fields) @@ -222,12 +261,13 @@ class SubmitFeedback(Resource): conversations_collection.update_one( { "_id": ObjectId(data["conversation_id"]), + "user": decoded_token.get("sub"), f"queries.{data['question_index']}": {"$exists": True}, }, { "$unset": { f"queries.{data['question_index']}.feedback": "", - f"queries.{data['question_index']}.feedback_timestamp": "" + f"queries.{data['question_index']}.feedback_timestamp": "", } }, ) @@ -236,12 +276,17 @@ class SubmitFeedback(Resource): conversations_collection.update_one( { "_id": ObjectId(data["conversation_id"]), + "user": decoded_token.get("sub"), f"queries.{data['question_index']}": {"$exists": True}, }, { "$set": { - f"queries.{data['question_index']}.feedback": data["feedback"], - f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(datetime.timezone.utc) + f"queries.{data['question_index']}.feedback": data[ + "feedback" + ], + f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now( + datetime.timezone.utc + ), } }, ) @@ -284,13 +329,18 @@ class DeleteOldIndexes(Resource): params={"source_id": "The source ID to delete"}, ) def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) source_id = request.args.get("source_id") if not source_id: return make_response( jsonify({"success": False, "message": "Missing required fields"}), 400 ) - doc = sources_collection.find_one({"_id": ObjectId(source_id), "user": "local"}) + doc = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": decoded_token.get("sub")} + ) if not doc: return make_response(jsonify({"status": "not found"}), 404) try: @@ -328,6 +378,9 @@ class UploadFile(Resource): description="Uploads a file to be vectorized and indexed", ) def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) data = request.form files = request.files.getlist("file") required_fields = ["user", "name"] @@ -343,7 +396,7 @@ class UploadFile(Resource): 400, ) - user = secure_filename(request.form["user"]) + user = secure_filename(decoded_token.get("sub")) job_name = secure_filename(request.form["name"]) try: save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) @@ -443,6 +496,9 @@ class UploadRemote(Resource): description="Uploads remote source for vectorization", ) def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) data = request.form required_fields = ["user", "source", "name", "data"] missing_fields = check_required_fields(data, required_fields) @@ -463,7 +519,7 @@ class UploadRemote(Resource): task = ingest_remote.delay( source_data=source_data, job_name=data["name"], - user=data["user"], + user=decoded_token.get("sub"), loader=data["source"], ) except Exception as err: @@ -519,7 +575,10 @@ class RedirectToSources(Resource): class PaginatedSources(Resource): @api.doc(description="Get document with pagination, sorting and filtering") def get(self): - user = "local" + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") sort_field = request.args.get("sort", "date") # Default to 'date' sort_order = request.args.get("order", "desc") # Default to 'desc' page = int(request.args.get("page", 1)) # Default to 1 @@ -584,7 +643,10 @@ class PaginatedSources(Resource): class CombinedJson(Resource): @api.doc(description="Provide JSON file with combined available indexes") def get(self): - user = "local" + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = [ { "name": "Default", @@ -685,13 +747,16 @@ class CreatePrompt(Resource): @api.expect(create_prompt_model) @api.doc(description="Create a new prompt") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) data = request.get_json() required_fields = ["content", "name"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields - user = "local" + user = decoded_token.get("sub") try: resp = prompts_collection.insert_one( @@ -713,7 +778,10 @@ class CreatePrompt(Resource): class GetPrompts(Resource): @api.doc(description="Get all prompts for the user") def get(self): - user = "local" + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") try: prompts = prompts_collection.find({"user": user}) list_prompts = [ @@ -741,6 +809,10 @@ class GetPrompts(Resource): class GetSinglePrompt(Resource): @api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID") def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") prompt_id = request.args.get("id") if not prompt_id: return make_response( @@ -771,7 +843,9 @@ class GetSinglePrompt(Resource): chat_reduce_strict = f.read() return make_response(jsonify({"content": chat_reduce_strict}), 200) - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)}) + prompt = prompts_collection.find_one( + {"_id": ObjectId(prompt_id), "user": user} + ) except Exception as err: current_app.logger.error(f"Error retrieving prompt: {err}") return make_response(jsonify({"success": False}), 400) @@ -789,6 +863,10 @@ class DeletePrompt(Resource): @api.expect(delete_prompt_model) @api.doc(description="Delete a prompt by ID") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id"] missing_fields = check_required_fields(data, required_fields) @@ -796,7 +874,7 @@ class DeletePrompt(Resource): return missing_fields try: - prompts_collection.delete_one({"_id": ObjectId(data["id"])}) + prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user}) except Exception as err: current_app.logger.error(f"Error deleting prompt: {err}") return make_response(jsonify({"success": False}), 400) @@ -820,6 +898,10 @@ class UpdatePrompt(Resource): @api.expect(update_prompt_model) @api.doc(description="Update an existing prompt") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id", "name", "content"] missing_fields = check_required_fields(data, required_fields) @@ -828,7 +910,7 @@ class UpdatePrompt(Resource): try: prompts_collection.update_one( - {"_id": ObjectId(data["id"])}, + {"_id": ObjectId(data["id"]), "user": user}, {"$set": {"name": data["name"], "content": data["content"]}}, ) except Exception as err: @@ -842,7 +924,10 @@ class UpdatePrompt(Resource): class GetApiKeys(Resource): @api.doc(description="Retrieve API keys for the user") def get(self): - user = "local" + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") try: keys = api_key_collection.find({"user": user}) list_keys = [] @@ -889,13 +974,16 @@ class CreateApiKey(Resource): @api.expect(create_api_key_model) @api.doc(description="Create a new API key") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["name", "prompt_id", "chunks"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields - user = "local" try: key = str(uuid.uuid4()) new_api_key = { @@ -929,6 +1017,10 @@ class DeleteApiKey(Resource): @api.expect(delete_api_key_model) @api.doc(description="Delete an API key by ID") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id"] missing_fields = check_required_fields(data, required_fields) @@ -936,7 +1028,9 @@ class DeleteApiKey(Resource): return missing_fields try: - result = api_key_collection.delete_one({"_id": ObjectId(data["id"])}) + result = api_key_collection.delete_one( + {"_id": ObjectId(data["id"]), "user": user} + ) if result.deleted_count == 0: return {"success": False, "message": "API Key not found"}, 404 except Exception as err: @@ -963,6 +1057,10 @@ class ShareConversation(Resource): @api.expect(share_conversation_model) @api.doc(description="Share a conversation") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["conversation_id"] missing_fields = check_required_fields(data, required_fields) @@ -974,8 +1072,6 @@ class ShareConversation(Resource): return make_response( jsonify({"success": False, "message": "isPromptable is required"}), 400 ) - - user = data.get("user", "local") conversation_id = data["conversation_id"] try: @@ -1211,7 +1307,13 @@ class GetMessageAnalytics(Resource): required=False, description="Filter option for analytics", default="last_30_days", - enum=["last_hour", "last_24_hour", "last_7_days", "last_15_days", "last_30_days"], + enum=[ + "last_hour", + "last_24_hour", + "last_7_days", + "last_15_days", + "last_30_days", + ], ), }, ) @@ -1219,13 +1321,19 @@ class GetMessageAnalytics(Resource): @api.expect(get_message_analytics_model) @api.doc(description="Get message analytics based on filter option") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() api_key_id = data.get("api_key_id") filter_option = data.get("filter_option", "last_30_days") try: api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + api_key_collection.find_one( + {"_id": ObjectId(api_key_id), "user": user} + )["key"] if api_key_id else None ) @@ -1244,9 +1352,9 @@ class GetMessageAnalytics(Resource): else: if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: filter_days = ( - 6 if filter_option == "last_7_days" - else 14 if filter_option == "last_15_days" - else 29 + 6 + if filter_option == "last_7_days" + else 14 if filter_option == "last_15_days" else 29 ) else: return make_response( @@ -1254,41 +1362,40 @@ class GetMessageAnalytics(Resource): ) start_date = end_date - datetime.timedelta(days=filter_days) start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) - end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) + end_date = end_date.replace( + hour=23, minute=59, second=59, microsecond=999999 + ) group_format = "%Y-%m-%d" try: + match_stage = { + "$match": { + "user": user, + } + } + if api_key: + match_stage["$match"]["api_key"] = api_key + pipeline = [ - # Initial match for API key if provided - { - "$match": { - "api_key": api_key if api_key else {"$exists": False} - } - }, + match_stage, {"$unwind": "$queries"}, - # Match queries within the time range { "$match": { - "queries.timestamp": { - "$gte": start_date, - "$lte": end_date - } + "queries.timestamp": {"$gte": start_date, "$lte": end_date} } }, - # Group by formatted timestamp { "$group": { "_id": { "$dateToString": { "format": group_format, - "date": "$queries.timestamp" + "date": "$queries.timestamp", } }, - "count": {"$sum": 1} + "count": {"$sum": 1}, } }, - # Sort by timestamp - {"$sort": {"_id": 1}} + {"$sort": {"_id": 1}}, ] message_data = conversations_collection.aggregate(pipeline) @@ -1338,13 +1445,19 @@ class GetTokenAnalytics(Resource): @api.expect(get_token_analytics_model) @api.doc(description="Get token analytics data") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() api_key_id = data.get("api_key_id") filter_option = data.get("filter_option", "last_30_days") try: api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + api_key_collection.find_one( + {"_id": ObjectId(api_key_id), "user": user} + )["key"] if api_key_id else None ) @@ -1426,13 +1539,12 @@ class GetTokenAnalytics(Resource): try: match_stage = { "$match": { + "user_id": user, "timestamp": {"$gte": start_date, "$lte": end_date}, } } if api_key: match_stage["$match"]["api_key"] = api_key - else: - match_stage["$match"]["api_key"] = {"$exists": False} token_usage_data = token_usage_collection.aggregate( [ @@ -1492,13 +1604,19 @@ class GetFeedbackAnalytics(Resource): @api.expect(get_feedback_analytics_model) @api.doc(description="Get feedback analytics data") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() api_key_id = data.get("api_key_id") filter_option = data.get("filter_option", "last_30_days") try: api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + api_key_collection.find_one( + {"_id": ObjectId(api_key_id), "user": user} + )["key"] if api_key_id else None ) @@ -1511,11 +1629,21 @@ class GetFeedbackAnalytics(Resource): if filter_option == "last_hour": start_date = end_date - datetime.timedelta(hours=1) group_format = "%Y-%m-%d %H:%M:00" - date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}} + date_field = { + "$dateToString": { + "format": group_format, + "date": "$queries.feedback_timestamp", + } + } elif filter_option == "last_24_hour": start_date = end_date - datetime.timedelta(hours=24) group_format = "%Y-%m-%d %H:00" - date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}} + date_field = { + "$dateToString": { + "format": group_format, + "date": "$queries.feedback_timestamp", + } + } else: if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: filter_days = ( @@ -1533,21 +1661,26 @@ class GetFeedbackAnalytics(Resource): hour=23, minute=59, second=59, microsecond=999999 ) group_format = "%Y-%m-%d" - date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}} + date_field = { + "$dateToString": { + "format": group_format, + "date": "$queries.feedback_timestamp", + } + } try: match_stage = { "$match": { - "queries.feedback_timestamp": {"$gte": start_date, "$lte": end_date}, - "queries.feedback": {"$exists": True} + "queries.feedback_timestamp": { + "$gte": start_date, + "$lte": end_date, + }, + "queries.feedback": {"$exists": True}, } } if api_key: match_stage["$match"]["api_key"] = api_key - else: - match_stage["$match"]["api_key"] = {"$exists": False} - # Unwind the queries array to process each query separately pipeline = [ match_stage, {"$unwind": "$queries"}, @@ -1634,6 +1767,10 @@ class GetUserLogs(Resource): @api.expect(get_user_logs_model) @api.doc(description="Get user logs with pagination") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() page = int(data.get("page", 1)) api_key_id = data.get("api_key_id") @@ -1650,7 +1787,7 @@ class GetUserLogs(Resource): current_app.logger.error(f"Error getting API key: {err}") return make_response(jsonify({"success": False}), 400) - query = {} + query = {"user": user} if api_key: query = {"api_key": api_key} @@ -1708,6 +1845,10 @@ class ManageSync(Resource): @api.expect(manage_sync_model) @api.doc(description="Manage sync frequency for sources") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["source_id", "sync_frequency"] missing_fields = check_required_fields(data, required_fields) @@ -1727,7 +1868,7 @@ class ManageSync(Resource): sources_collection.update_one( { "_id": ObjectId(source_id), - "user": "local", + "user": user, }, update_data, ) @@ -1804,7 +1945,10 @@ class GetTools(Resource): @api.doc(description="Get tools created by a user") def get(self): try: - user = "local" + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") tools = user_tools_collection.find({"user": user}) user_tools = [] for tool in tools: @@ -1847,6 +1991,10 @@ class CreateTool(Resource): ) @api.doc(description="Create a new tool") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = [ "name", @@ -1860,7 +2008,6 @@ class CreateTool(Resource): if missing_fields: return missing_fields - user = "local" transformed_actions = [] for action in data["actions"]: action["active"] = True @@ -1911,6 +2058,10 @@ class UpdateTool(Resource): ) @api.doc(description="Update a tool by ID") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id"] missing_fields = check_required_fields(data, required_fields) @@ -1946,7 +2097,7 @@ class UpdateTool(Resource): update_data["status"] = data["status"] user_tools_collection.update_one( - {"_id": ObjectId(data["id"]), "user": "local"}, + {"_id": ObjectId(data["id"]), "user": user}, {"$set": update_data}, ) except Exception as err: @@ -1971,6 +2122,10 @@ class UpdateToolConfig(Resource): ) @api.doc(description="Update the configuration of a tool") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id", "config"] missing_fields = check_required_fields(data, required_fields) @@ -1979,7 +2134,7 @@ class UpdateToolConfig(Resource): try: user_tools_collection.update_one( - {"_id": ObjectId(data["id"])}, + {"_id": ObjectId(data["id"]), "user": user}, {"$set": {"config": data["config"]}}, ) except Exception as err: @@ -2006,6 +2161,10 @@ class UpdateToolActions(Resource): ) @api.doc(description="Update the actions of a tool") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id", "actions"] missing_fields = check_required_fields(data, required_fields) @@ -2014,7 +2173,7 @@ class UpdateToolActions(Resource): try: user_tools_collection.update_one( - {"_id": ObjectId(data["id"])}, + {"_id": ObjectId(data["id"]), "user": user}, {"$set": {"actions": data["actions"]}}, ) except Exception as err: @@ -2039,6 +2198,10 @@ class UpdateToolStatus(Resource): ) @api.doc(description="Update the status of a tool") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id", "status"] missing_fields = check_required_fields(data, required_fields) @@ -2047,7 +2210,7 @@ class UpdateToolStatus(Resource): try: user_tools_collection.update_one( - {"_id": ObjectId(data["id"])}, + {"_id": ObjectId(data["id"]), "user": user}, {"$set": {"status": data["status"]}}, ) except Exception as err: @@ -2067,6 +2230,10 @@ class DeleteTool(Resource): ) @api.doc(description="Delete a tool by ID") def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id"] missing_fields = check_required_fields(data, required_fields) @@ -2074,7 +2241,9 @@ class DeleteTool(Resource): return missing_fields try: - result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])}) + result = user_tools_collection.delete_one( + {"_id": ObjectId(data["id"]), "user": user} + ) if result.deleted_count == 0: return {"success": False, "message": "Tool not found"}, 404 except Exception as err: @@ -2084,21 +2253,6 @@ class DeleteTool(Resource): return {"success": True}, 200 -def get_vector_store(source_id): - """ - Get the Vector Store - Args: - source_id (str): source id of the document - """ - - store = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, - source_id=source_id, - embeddings_key=os.getenv("EMBEDDINGS_KEY"), - ) - return store - - @user_ns.route("/api/get_chunks") class GetChunks(Resource): @api.doc( @@ -2106,6 +2260,10 @@ class GetChunks(Resource): params={"id": "The document ID"}, ) def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") doc_id = request.args.get("id") page = int(request.args.get("page", 1)) per_page = int(request.args.get("per_page", 10)) @@ -2113,6 +2271,12 @@ class GetChunks(Resource): if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid doc_id"}), 400) + doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) + if not doc: + return make_response( + jsonify({"error": "Document not found or access denied"}), 404 + ) + try: store = get_vector_store(doc_id) chunks = store.get_chunks() @@ -2157,6 +2321,10 @@ class AddChunk(Resource): description="Adds a new chunk to the document", ) def post(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id", "text"] missing_fields = check_required_fields(data, required_fields) @@ -2170,6 +2338,12 @@ class AddChunk(Resource): if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid doc_id"}), 400) + doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) + if not doc: + return make_response( + jsonify({"error": "Document not found or access denied"}), 404 + ) + try: store = get_vector_store(doc_id) chunk_id = store.add_chunk(text, metadata) @@ -2189,12 +2363,22 @@ class DeleteChunk(Resource): params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"}, ) def delete(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") doc_id = request.args.get("id") chunk_id = request.args.get("chunk_id") if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid doc_id"}), 400) + doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) + if not doc: + return make_response( + jsonify({"error": "Document not found or access denied"}), 404 + ) + try: store = get_vector_store(doc_id) deleted = store.delete_chunk(chunk_id) @@ -2236,6 +2420,10 @@ class UpdateChunk(Resource): description="Updates an existing chunk in the document.", ) def put(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") data = request.get_json() required_fields = ["id", "chunk_id"] missing_fields = check_required_fields(data, required_fields) @@ -2250,6 +2438,12 @@ class UpdateChunk(Resource): if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid doc_id"}), 400) + doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) + if not doc: + return make_response( + jsonify({"error": "Document not found or access denied"}), 404 + ) + try: store = get_vector_store(doc_id) chunks = store.get_chunks() diff --git a/application/app.py b/application/app.py index 4eb40331..7ca0ac2b 100644 --- a/application/app.py +++ b/application/app.py @@ -1,20 +1,28 @@ +import os import platform +import uuid import dotenv -from flask import Flask, redirect, request +from flask import Flask, jsonify, redirect, request +from jose import jwt + +from application.auth import handle_auth + from application.core.logging_config import setup_logging + setup_logging() -from application.api.answer.routes import answer # noqa: E402 -from application.api.internal.routes import internal # noqa: E402 -from application.api.user.routes import user # noqa: E402 -from application.celery_init import celery # noqa: E402 -from application.core.settings import settings # noqa: E402 -from application.extensions import api # noqa: E402 +from application.api.answer.routes import answer # noqa: E402 +from application.api.internal.routes import internal # noqa: E402 +from application.api.user.routes import user # noqa: E402 +from application.celery_init import celery # noqa: E402 +from application.core.settings import settings # noqa: E402 +from application.extensions import api # noqa: E402 if platform.system() == "Windows": import pathlib + pathlib.PosixPath = pathlib.WindowsPath dotenv.load_dotenv() @@ -32,6 +40,25 @@ app.config.update( celery.config_from_object("application.celeryconfig") api.init_app(app) +if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECRET_KEY: + key_file = ".jwt_secret_key" + try: + with open(key_file, "r") as f: + settings.JWT_SECRET_KEY = f.read().strip() + except FileNotFoundError: + new_key = os.urandom(32).hex() + with open(key_file, "w") as f: + f.write(new_key) + settings.JWT_SECRET_KEY = new_key + except Exception as e: + raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}") + +SIMPLE_JWT_TOKEN = None +if settings.AUTH_TYPE == "simple_jwt": + payload = {"sub": "local"} + SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256") + print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}") + @app.route("/") def home(): @@ -41,11 +68,47 @@ def home(): return "Welcome to DocsGPT Backend!" +@app.route("/api/config") +def get_config(): + response = { + "auth_type": settings.AUTH_TYPE, + "requires_auth": settings.AUTH_TYPE in ["simple_jwt", "session_jwt"], + } + return jsonify(response) + + +@app.route("/api/generate_token") +def generate_token(): + if settings.AUTH_TYPE == "session_jwt": + new_user_id = str(uuid.uuid4()) + token = jwt.encode( + {"sub": new_user_id}, settings.JWT_SECRET_KEY, algorithm="HS256" + ) + return jsonify({"token": token}) + return jsonify({"error": "Token generation not allowed in current auth mode"}), 400 + + +@app.before_request +def authenticate_request(): + if request.method == "OPTIONS": + return "", 200 + + decoded_token = handle_auth(request) + if not decoded_token: + request.decoded_token = None + elif "error" in decoded_token: + return jsonify(decoded_token), 401 + else: + request.decoded_token = decoded_token + + @app.after_request def after_request(response): response.headers.add("Access-Control-Allow-Origin", "*") - response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization") - response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") + response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization") + response.headers.add( + "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS" + ) return response diff --git a/application/auth.py b/application/auth.py new file mode 100644 index 00000000..78926c45 --- /dev/null +++ b/application/auth.py @@ -0,0 +1,28 @@ +from jose import jwt + +from application.core.settings import settings + + +def handle_auth(request, data={}): + if settings.AUTH_TYPE in ["simple_jwt", "session_jwt"]: + jwt_token = request.headers.get("Authorization") + if not jwt_token: + return None + + jwt_token = jwt_token.replace("Bearer ", "") + + try: + decoded_token = jwt.decode( + jwt_token, + settings.JWT_SECRET_KEY, + algorithms=["HS256"], + options={"verify_exp": False}, + ) + return decoded_token + except Exception as e: + return { + "message": f"Authentication error: {str(e)}", + "error": "invalid_token", + } + else: + return {"sub": "local"} diff --git a/application/core/settings.py b/application/core/settings.py index 04d7bbea..74bffe53 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -10,6 +10,7 @@ current_dir = os.path.dirname( class Settings(BaseSettings): + AUTH_TYPE: Optional[str] = None LLM_NAME: str = "docsgpt" MODEL_NAME: Optional[str] = ( None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo @@ -98,6 +99,8 @@ class Settings(BaseSettings): FLASK_DEBUG_MODE: bool = False + JWT_SECRET_KEY: str = "" + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/llm/base.py b/application/llm/base.py index e687e567..0fce208c 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -5,7 +5,8 @@ from application.usage import gen_token_usage, stream_token_usage class BaseLLM(ABC): - def __init__(self): + def __init__(self, decoded_token=None): + self.decoded_token = decoded_token self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} def _apply_decorator(self, method, decorators, *args, **kwargs): diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py index 9f1305ba..3ed23854 100644 --- a/application/llm/llm_creator.py +++ b/application/llm/llm_creator.py @@ -9,6 +9,7 @@ from application.llm.premai import PremAILLM from application.llm.google_ai import GoogleLLM from application.llm.novita import NovitaLLM + class LLMCreator: llms = { "openai": OpenAILLM, @@ -21,12 +22,14 @@ class LLMCreator: "premai": PremAILLM, "groq": GroqLLM, "google": GoogleLLM, - "novita": NovitaLLM + "novita": NovitaLLM, } @classmethod - def create_llm(cls, type, api_key, user_api_key, *args, **kwargs): + def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs): llm_class = cls.llms.get(type.lower()) if not llm_class: raise ValueError(f"No LLM class found for type {type}") - return llm_class(api_key, user_api_key, *args, **kwargs) + return llm_class( + api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs + ) diff --git a/application/requirements.txt b/application/requirements.txt index 713ae2e3..5323fe85 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -69,6 +69,7 @@ pymongo==4.10.1 pypdf==5.2.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 +python-jose==3.4.0 python-pptx==1.0.2 qdrant-client==1.13.2 redis==5.2.1 diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 08b16bc0..ed490734 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -17,6 +17,7 @@ class BraveRetSearch(BaseRetriever): token_limit=150, gpt_model="docsgpt", user_api_key=None, + decoded_token=None, ): self.question = question self.source = source @@ -35,6 +36,7 @@ class BraveRetSearch(BaseRetriever): ) ) self.user_api_key = user_api_key + self.decoded_token = decoded_token def _get_data(self): if self.chunks == 0: @@ -81,7 +83,10 @@ class BraveRetSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + settings.LLM_NAME, + api_key=settings.API_KEY, + user_api_key=self.user_api_key, + decoded_token=self.decoded_token, ) completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) @@ -100,5 +105,5 @@ class BraveRetSearch(BaseRetriever): "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, - "user_api_key": self.user_api_key + "user_api_key": self.user_api_key, } diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 03f17f44..08771337 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -17,6 +17,7 @@ class ClassicRAG(BaseRetriever): user_api_key=None, llm_name=settings.LLM_NAME, api_key=settings.API_KEY, + decoded_token=None, ): self.original_question = "" self.chat_history = chat_history if chat_history is not None else [] @@ -37,10 +38,14 @@ class ClassicRAG(BaseRetriever): self.llm_name = llm_name self.api_key = api_key self.llm = LLMCreator.create_llm( - self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key + self.llm_name, + api_key=self.api_key, + user_api_key=self.user_api_key, + decoded_token=decoded_token, ) self.question = self._rephrase_query() self.vectorstore = source["active_docs"] if "active_docs" in source else None + self.decoded_token = decoded_token def _rephrase_query(self): if ( diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index c6386410..9ce73995 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -17,6 +17,7 @@ class DuckDuckSearch(BaseRetriever): token_limit=150, gpt_model="docsgpt", user_api_key=None, + decoded_token=None, ): self.question = question self.source = source @@ -35,6 +36,7 @@ class DuckDuckSearch(BaseRetriever): ) ) self.user_api_key = user_api_key + self.decoded_token = decoded_token def _parse_lang_string(self, input_string): result = [] @@ -88,17 +90,20 @@ class DuckDuckSearch(BaseRetriever): for doc in docs: yield {"source": doc} - if len(self.chat_history) > 0: + if len(self.chat_history) > 0: for i in self.chat_history: - if "prompt" in i and "response" in i: - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append( - {"role": "assistant", "content": i["response"]} - ) + if "prompt" in i and "response" in i: + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append( + {"role": "assistant", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + settings.LLM_NAME, + api_key=settings.API_KEY, + user_api_key=self.user_api_key, + decoded_token=self.decoded_token, ) completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) @@ -107,7 +112,7 @@ class DuckDuckSearch(BaseRetriever): def search(self): return self._get_data() - + def get_params(self): return { "question": self.question, @@ -117,5 +122,5 @@ class DuckDuckSearch(BaseRetriever): "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, - "user_api_key": self.user_api_key + "user_api_key": self.user_api_key, } diff --git a/application/usage.py b/application/usage.py index a18a3848..85328c1f 100644 --- a/application/usage.py +++ b/application/usage.py @@ -9,10 +9,15 @@ db = mongo["docsgpt"] usage_collection = db["token_usage"] -def update_token_usage(user_api_key, token_usage): +def update_token_usage(decoded_token, user_api_key, token_usage): if "pytest" in sys.modules: return + if decoded_token: + user_id = decoded_token["sub"] + else: + user_id = None usage_data = { + "user_id": user_id, "api_key": user_api_key, "prompt_tokens": token_usage["prompt_tokens"], "generated_tokens": token_usage["generated_tokens"], @@ -35,7 +40,7 @@ def gen_token_usage(func): self.token_usage["generated_tokens"] += num_tokens_from_object_or_list( result ) - update_token_usage(self.user_api_key, self.token_usage) + update_token_usage(self.decoded_token, self.user_api_key, self.token_usage) return result return wrapper @@ -54,6 +59,6 @@ def stream_token_usage(func): yield r for line in batch: self.token_usage["generated_tokens"] += num_tokens_from_string(line) - update_token_usage(self.user_api_key, self.token_usage) + update_token_usage(self.decoded_token, self.user_api_key, self.token_usage) return wrapper diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index ba0a4bd7..64c4c486 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,15 +1,30 @@ -import { Routes, Route } from 'react-router-dom'; -import Navigation from './Navigation'; -import Conversation from './conversation/Conversation'; -import About from './About'; -import PageNotFound from './PageNotFound'; -import { useMediaQuery } from './hooks'; -import { useState } from 'react'; -import Setting from './settings'; import './locale/i18n'; -import { Outlet } from 'react-router-dom'; + +import { useState } from 'react'; +import { Outlet, Route, Routes } from 'react-router-dom'; + +import About from './About'; +import Spinner from './components/Spinner'; +import Conversation from './conversation/Conversation'; import { SharedConversation } from './conversation/SharedConversation'; -import { useDarkTheme } from './hooks'; +import { useDarkTheme, useMediaQuery } from './hooks'; +import useTokenAuth from './hooks/useTokenAuth'; +import Navigation from './Navigation'; +import PageNotFound from './PageNotFound'; +import Setting from './settings'; + +function AuthWrapper({ children }: { children: React.ReactNode }) { + const { isAuthLoading } = useTokenAuth(); + + if (isAuthLoading) { + return ( +
+ +
+ ); + } + return <>{children}; +} function MainLayout() { const { isMobile } = useMediaQuery(); @@ -39,7 +54,13 @@ export default function App() { return (
- }> + + + + } + > } /> } /> } /> diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index 49795b41..9e1889aa 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -2,28 +2,35 @@ import { useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; import { NavLink, useNavigate } from 'react-router-dom'; + import conversationService from './api/services/conversationService'; import userService from './api/services/userService'; import Add from './assets/add.svg'; -import openNewChat from './assets/openNewChat.svg'; -import Hamburger from './assets/hamburger.svg'; import DocsGPT3 from './assets/cute_docsgpt3.svg'; import Discord from './assets/discord.svg'; import Expand from './assets/expand.svg'; import Github from './assets/github.svg'; +import Hamburger from './assets/hamburger.svg'; +import openNewChat from './assets/openNewChat.svg'; import SettingGear from './assets/settingGear.svg'; +import SpinnerDark from './assets/spinner-dark.svg'; +import Spinner from './assets/spinner.svg'; import Twitter from './assets/TwitterX.svg'; import UploadIcon from './assets/upload.svg'; +import Help from './components/Help'; import SourceDropdown from './components/SourceDropdown'; import { + handleAbort, + selectQueries, setConversation, updateConversationId, - handleAbort, } from './conversation/conversationSlice'; import ConversationTile from './conversation/ConversationTile'; import { useDarkTheme, useMediaQuery } from './hooks'; import useDefaultDocument from './hooks/useDefaultDocument'; +import useTokenAuth from './hooks/useTokenAuth'; import DeleteConvModal from './modals/DeleteConvModal'; +import JWTModal from './modals/JWTModal'; import { ActiveState, Doc } from './models/misc'; import { getConversations, getDocs } from './preferences/preferenceApi'; import { @@ -31,20 +38,17 @@ import { selectConversationId, selectConversations, selectModalStateDeleteConv, + selectPaginatedDocuments, selectSelectedDocs, selectSourceDocs, - selectPaginatedDocuments, + selectToken, setConversations, setModalStateDeleteConv, + setPaginatedDocuments, setSelectedDocs, setSourceDocs, - setPaginatedDocuments, } from './preferences/preferenceSlice'; -import Spinner from './assets/spinner.svg'; -import SpinnerDark from './assets/spinner-dark.svg'; -import { selectQueries } from './conversation/conversationSlice'; import Upload from './upload/Upload'; -import Help from './components/Help'; interface NavigationProps { navOpen: boolean; @@ -53,6 +57,7 @@ interface NavigationProps { export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const dispatch = useDispatch(); + const token = useSelector(selectToken); const queries = useSelector(selectQueries); const docs = useSelector(selectSourceDocs); const selectedDocs = useSelector(selectSelectedDocs); @@ -68,6 +73,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const { t } = useTranslation(); const isApiKeySet = useSelector(selectApiKeyStatus); + const { showTokenModal, handleTokenSubmit } = useTokenAuth(); + const [uploadModalState, setUploadModalState] = useState('INACTIVE'); @@ -86,7 +93,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { async function fetchConversations() { dispatch(setConversations({ ...conversations, loading: true })); - return await getConversations() + return await getConversations(token) .then((fetchedConversations) => { dispatch(setConversations(fetchedConversations)); }) @@ -99,7 +106,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleDeleteAllConversations = () => { setIsDeletingConversation(true); conversationService - .deleteAll() + .deleteAll(token) .then(() => { fetchConversations(); }) @@ -109,7 +116,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleDeleteConversation = (id: string) => { setIsDeletingConversation(true); conversationService - .delete(id, {}) + .delete(id, {}, token) .then(() => { fetchConversations(); resetConversation(); @@ -119,9 +126,9 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleDeleteClick = (doc: Doc) => { userService - .deletePath(doc.id ?? '') + .deletePath(doc.id ?? '', token) .then(() => { - return getDocs(); + return getDocs(token); }) .then((updatedDocs) => { dispatch(setSourceDocs(updatedDocs)); @@ -145,7 +152,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleConversationClick = (index: string) => { conversationService - .getConversation(index) + .getConversation(index, token) .then((response) => response.json()) .then((data) => { navigate('/'); @@ -177,7 +184,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { id: string; }) { await conversationService - .update(updatedConversation) + .update(updatedConversation, token) .then((response) => response.json()) .then((data) => { if (data) { @@ -197,8 +204,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { useEffect(() => { setNavOpen(!isMobile); }, [isMobile]); - useDefaultDocument(); + useDefaultDocument(); return ( <> {!navOpen && ( @@ -472,6 +479,10 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { close={() => setUploadModalState('INACTIVE')} > )} + ); } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 21699721..3db613fc 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -4,14 +4,24 @@ const defaultHeaders = { 'Content-Type': 'application/json', }; +const getHeaders = (token: string | null, customHeaders = {}): HeadersInit => { + return { + ...defaultHeaders, + ...(token ? { Authorization: `Bearer ${token}` } : {}), + ...customHeaders, + }; +}; + const apiClient = { - get: (url: string, headers = {}, signal?: AbortSignal): Promise => + get: ( + url: string, + token: string | null, + headers = {}, + signal?: AbortSignal, + ): Promise => fetch(`${baseURL}${url}`, { method: 'GET', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), signal, }).then((response) => { return response; @@ -20,15 +30,13 @@ const apiClient = { post: ( url: string, data: any, + token: string | null, headers = {}, signal?: AbortSignal, ): Promise => fetch(`${baseURL}${url}`, { method: 'POST', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), body: JSON.stringify(data), signal, }).then((response) => { @@ -38,28 +46,28 @@ const apiClient = { put: ( url: string, data: any, + token: string | null, headers = {}, signal?: AbortSignal, ): Promise => fetch(`${baseURL}${url}`, { method: 'PUT', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), body: JSON.stringify(data), signal, }).then((response) => { return response; }), - delete: (url: string, headers = {}, signal?: AbortSignal): Promise => + delete: ( + url: string, + token: string | null, + headers = {}, + signal?: AbortSignal, + ): Promise => fetch(`${baseURL}${url}`, { method: 'DELETE', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), signal, }).then((response) => { return response; diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 9bf659de..0d574f89 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -1,5 +1,7 @@ const endpoints = { USER: { + CONFIG: '/api/config', + NEW_TOKEN: '/api/generate_token', DOCS: '/api/sources', DOCS_CHECK: '/api/docs_check', DOCS_PAGINATED: '/api/sources/paginated', diff --git a/frontend/src/api/services/conversationService.ts b/frontend/src/api/services/conversationService.ts index aaf703de..853a6863 100644 --- a/frontend/src/api/services/conversationService.ts +++ b/frontend/src/api/services/conversationService.ts @@ -2,31 +2,58 @@ import apiClient from '../client'; import endpoints from '../endpoints'; const conversationService = { - answer: (data: any, signal: AbortSignal): Promise => - apiClient.post(endpoints.CONVERSATION.ANSWER, data, {}, signal), - answerStream: (data: any, signal: AbortSignal): Promise => - apiClient.post(endpoints.CONVERSATION.ANSWER_STREAMING, data, {}, signal), - search: (data: any): Promise => - apiClient.post(endpoints.CONVERSATION.SEARCH, data), - feedback: (data: any): Promise => - apiClient.post(endpoints.CONVERSATION.FEEDBACK, data), - getConversation: (id: string): Promise => - apiClient.get(endpoints.CONVERSATION.CONVERSATION(id)), - getConversations: (): Promise => - apiClient.get(endpoints.CONVERSATION.CONVERSATIONS), - shareConversation: (isPromptable: boolean, data: any): Promise => + answer: ( + data: any, + token: string | null, + signal: AbortSignal, + ): Promise => + apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, {}, signal), + answerStream: ( + data: any, + token: string | null, + signal: AbortSignal, + ): Promise => + apiClient.post( + endpoints.CONVERSATION.ANSWER_STREAMING, + data, + token, + {}, + signal, + ), + search: (data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.SEARCH, data, token, {}), + feedback: (data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.FEEDBACK, data, token, {}), + getConversation: (id: string, token: string | null): Promise => + apiClient.get(endpoints.CONVERSATION.CONVERSATION(id), token, {}), + getConversations: (token: string | null): Promise => + apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}), + shareConversation: ( + isPromptable: boolean, + data: any, + token: string | null, + ): Promise => apiClient.post( endpoints.CONVERSATION.SHARE_CONVERSATION(isPromptable), data, + token, + {}, ), - getSharedConversation: (identifier: string): Promise => - apiClient.get(endpoints.CONVERSATION.SHARED_CONVERSATION(identifier)), - delete: (id: string, data: any): Promise => - apiClient.post(endpoints.CONVERSATION.DELETE(id), data), - deleteAll: (): Promise => - apiClient.get(endpoints.CONVERSATION.DELETE_ALL), - update: (data: any): Promise => - apiClient.post(endpoints.CONVERSATION.UPDATE, data), + getSharedConversation: ( + identifier: string, + token: string | null, + ): Promise => + apiClient.get( + endpoints.CONVERSATION.SHARED_CONVERSATION(identifier), + token, + {}, + ), + delete: (id: string, data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.DELETE(id), data, token, {}), + deleteAll: (token: string | null): Promise => + apiClient.get(endpoints.CONVERSATION.DELETE_ALL, token, {}), + update: (data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.UPDATE, data, token, {}), }; export default conversationService; diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index e7f367f1..13083677 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -2,63 +2,74 @@ import apiClient from '../client'; import endpoints from '../endpoints'; const userService = { - getDocs: (): Promise => apiClient.get(`${endpoints.USER.DOCS}`), - getDocsWithPagination: (query: string): Promise => - apiClient.get(`${endpoints.USER.DOCS_PAGINATED}?${query}`), - checkDocs: (data: any): Promise => - apiClient.post(endpoints.USER.DOCS_CHECK, data), - getAPIKeys: (): Promise => apiClient.get(endpoints.USER.API_KEYS), - createAPIKey: (data: any): Promise => - apiClient.post(endpoints.USER.CREATE_API_KEY, data), - deleteAPIKey: (data: any): Promise => - apiClient.post(endpoints.USER.DELETE_API_KEY, data), - getPrompts: (): Promise => apiClient.get(endpoints.USER.PROMPTS), - createPrompt: (data: any): Promise => - apiClient.post(endpoints.USER.CREATE_PROMPT, data), - deletePrompt: (data: any): Promise => - apiClient.post(endpoints.USER.DELETE_PROMPT, data), - updatePrompt: (data: any): Promise => - apiClient.post(endpoints.USER.UPDATE_PROMPT, data), - getSinglePrompt: (id: string): Promise => - apiClient.get(endpoints.USER.SINGLE_PROMPT(id)), - deletePath: (docPath: string): Promise => - apiClient.get(endpoints.USER.DELETE_PATH(docPath)), - getTaskStatus: (task_id: string): Promise => - apiClient.get(endpoints.USER.TASK_STATUS(task_id)), - getMessageAnalytics: (data: any): Promise => - apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data), - getTokenAnalytics: (data: any): Promise => - apiClient.post(endpoints.USER.TOKEN_ANALYTICS, data), - getFeedbackAnalytics: (data: any): Promise => - apiClient.post(endpoints.USER.FEEDBACK_ANALYTICS, data), - getLogs: (data: any): Promise => - apiClient.post(endpoints.USER.LOGS, data), - manageSync: (data: any): Promise => - apiClient.post(endpoints.USER.MANAGE_SYNC, data), - getAvailableTools: (): Promise => - apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS), - getUserTools: (): Promise => - apiClient.get(endpoints.USER.GET_USER_TOOLS), - createTool: (data: any): Promise => - apiClient.post(endpoints.USER.CREATE_TOOL, data), - updateToolStatus: (data: any): Promise => - apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data), - updateTool: (data: any): Promise => - apiClient.post(endpoints.USER.UPDATE_TOOL, data), - deleteTool: (data: any): Promise => - apiClient.post(endpoints.USER.DELETE_TOOL, data), + getConfig: (): Promise => apiClient.get(endpoints.USER.CONFIG, null), + getNewToken: (): Promise => + apiClient.get(endpoints.USER.NEW_TOKEN, null), + getDocs: (token: string | null): Promise => + apiClient.get(`${endpoints.USER.DOCS}`, token), + getDocsWithPagination: (query: string, token: string | null): Promise => + apiClient.get(`${endpoints.USER.DOCS_PAGINATED}?${query}`, token), + checkDocs: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DOCS_CHECK, data, token), + getAPIKeys: (token: string | null): Promise => + apiClient.get(endpoints.USER.API_KEYS, token), + createAPIKey: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.CREATE_API_KEY, data, token), + deleteAPIKey: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DELETE_API_KEY, data, token), + getPrompts: (token: string | null): Promise => + apiClient.get(endpoints.USER.PROMPTS, token), + createPrompt: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.CREATE_PROMPT, data, token), + deletePrompt: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DELETE_PROMPT, data, token), + updatePrompt: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.UPDATE_PROMPT, data, token), + getSinglePrompt: (id: string, token: string | null): Promise => + apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token), + deletePath: (docPath: string, token: string | null): Promise => + apiClient.get(endpoints.USER.DELETE_PATH(docPath), token), + getTaskStatus: (task_id: string, token: string | null): Promise => + apiClient.get(endpoints.USER.TASK_STATUS(task_id), token), + getMessageAnalytics: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token), + getTokenAnalytics: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.TOKEN_ANALYTICS, data, token), + getFeedbackAnalytics: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.FEEDBACK_ANALYTICS, data, token), + getLogs: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.LOGS, data, token), + manageSync: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MANAGE_SYNC, data, token), + getAvailableTools: (token: string | null): Promise => + apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token), + getUserTools: (token: string | null): Promise => + apiClient.get(endpoints.USER.GET_USER_TOOLS, token), + createTool: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.CREATE_TOOL, data, token), + updateToolStatus: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data, token), + updateTool: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL, data, token), + deleteTool: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DELETE_TOOL, data, token), getDocumentChunks: ( docId: string, page: number, perPage: number, + token: string | null, ): Promise => - apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage)), - addChunk: (data: any): Promise => - apiClient.post(endpoints.USER.ADD_CHUNK, data), - deleteChunk: (docId: string, chunkId: string): Promise => - apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId)), - updateChunk: (data: any): Promise => - apiClient.put(endpoints.USER.UPDATE_CHUNK, data), + apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage), token), + addChunk: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.ADD_CHUNK, data, token), + deleteChunk: ( + docId: string, + chunkId: string, + token: string | null, + ): Promise => + apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId), token), + updateChunk: (data: any, token: string | null): Promise => + apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token), }; export default userService; diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx index 9f54ddbd..2dd9e773 100644 --- a/frontend/src/conversation/Conversation.tsx +++ b/frontend/src/conversation/Conversation.tsx @@ -7,7 +7,10 @@ import newChatIcon from '../assets/openNewChat.svg'; import ShareIcon from '../assets/share.svg'; import { useMediaQuery } from '../hooks'; import { ShareConversationModal } from '../modals/ShareConversationModal'; -import { selectConversationId } from '../preferences/preferenceSlice'; +import { + selectConversationId, + selectToken, +} from '../preferences/preferenceSlice'; import { AppDispatch } from '../store'; import { handleSendFeedback } from './conversationHandlers'; import { FEEDBACK, Query } from './conversationModels'; @@ -27,6 +30,7 @@ import ConversationMessages from './ConversationMessages'; import MessageInput from '../components/MessageInput'; export default function Conversation() { + const token = useSelector(selectToken); const queries = useSelector(selectQueries); const status = useSelector(selectStatus); const conversationId = useSelector(selectConversationId); @@ -118,6 +122,7 @@ export default function Conversation() { feedback, conversationId as string, index, + token, ).catch(() => handleSendFeedback( query.prompt, @@ -125,6 +130,7 @@ export default function Conversation() { feedback, conversationId as string, index, + token, ).catch(() => dispatch(updateQuery({ index, query: { feedback: prevFeedback } })), ), diff --git a/frontend/src/conversation/SharedConversation.tsx b/frontend/src/conversation/SharedConversation.tsx index 993556e4..be822805 100644 --- a/frontend/src/conversation/SharedConversation.tsx +++ b/frontend/src/conversation/SharedConversation.tsx @@ -1,35 +1,34 @@ import { useEffect, useState } from 'react'; +import { Helmet } from 'react-helmet'; import { useTranslation } from 'react-i18next'; +import { useDispatch, useSelector } from 'react-redux'; import { useNavigate, useParams } from 'react-router-dom'; -import ConversationMessages from './ConversationMessages'; -import MessageInput from '../components/MessageInput'; + import conversationService from '../api/services/conversationService'; +import MessageInput from '../components/MessageInput'; +import { selectToken } from '../preferences/preferenceSlice'; +import { AppDispatch } from '../store'; +import { formatDate } from '../utils/dateTimeUtils'; +import ConversationMessages from './ConversationMessages'; import { - selectClientAPIKey, - setClientApiKey, - updateQuery, addQuery, fetchSharedAnswer, - selectStatus, -} from './sharedConversationSlice'; -import { setIdentifier, setFetchedData } from './sharedConversationSlice'; - -import { useDispatch } from 'react-redux'; -import { AppDispatch } from '../store'; - -import { + selectClientAPIKey, selectDate, - selectTitle, selectQueries, + selectStatus, + selectTitle, + setClientApiKey, + setFetchedData, + setIdentifier, + updateQuery, } from './sharedConversationSlice'; -import { useSelector } from 'react-redux'; -import { Helmet } from 'react-helmet'; -import { formatDate } from '../utils/dateTimeUtils'; export const SharedConversation = () => { const navigate = useNavigate(); const { identifier } = useParams(); //identifier is a uuid, not conversationId + const token = useSelector(selectToken); const queries = useSelector(selectQueries); const title = useSelector(selectTitle); const date = useSelector(selectDate); @@ -56,7 +55,7 @@ export const SharedConversation = () => { const fetchQueries = () => { identifier && conversationService - .getSharedConversation(identifier || '') + .getSharedConversation(identifier || '', token) .then((res) => { if (res.status === 404 || res.status === 400) navigate('/pagenotfound'); diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index 0b54a366..88771fc5 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -6,6 +6,7 @@ import { ToolCallsType } from './types'; export function handleFetchAnswer( question: string, signal: AbortSignal, + token: string | null, selectedDocs: Doc | null, history: Array = [], conversationId: string | null, @@ -52,7 +53,7 @@ export function handleFetchAnswer( } payload.retriever = selectedDocs?.retriever as string; return conversationService - .answer(payload, signal) + .answer(payload, token, signal) .then((response) => { if (response.ok) { return response.json(); @@ -76,6 +77,7 @@ export function handleFetchAnswer( export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, + token: string | null, selectedDocs: Doc | null, history: Array = [], conversationId: string | null, @@ -109,7 +111,7 @@ export function handleFetchAnswerSteaming( return new Promise((resolve, reject) => { conversationService - .answerStream(payload, signal) + .answerStream(payload, token, signal) .then((response) => { if (!response.body) throw Error('No response body'); @@ -160,6 +162,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, + token: string | null, selectedDocs: Doc | null, conversation_id: string | null, history: Array = [], @@ -185,7 +188,7 @@ export function handleSearch( payload.active_docs = selectedDocs.id as string; payload.retriever = selectedDocs?.retriever as string; return conversationService - .search(payload) + .search(payload, token) .then((response) => response.json()) .then((data) => { return data; @@ -206,11 +209,14 @@ export function handleSearchViaApiKey( }; }); return conversationService - .search({ - question: question, - history: JSON.stringify(history), - api_key: api_key, - }) + .search( + { + question: question, + history: JSON.stringify(history), + api_key: api_key, + }, + null, + ) .then((response) => response.json()) .then((data) => { return data; @@ -224,15 +230,19 @@ export function handleSendFeedback( feedback: FEEDBACK, conversation_id: string, prompt_index: number, + token: string | null, ) { return conversationService - .feedback({ - question: prompt, - answer: response, - feedback: feedback, - conversation_id: conversation_id, - question_index: prompt_index, - }) + .feedback( + { + question: prompt, + answer: response, + feedback: feedback, + conversation_id: conversation_id, + question_index: prompt_index, + }, + token, + ) .then((response) => { if (response.ok) { return Promise.resolve(); @@ -265,7 +275,7 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations save_conversation: false, }; conversationService - .answerStream(payload, signal) + .answerStream(payload, null, signal) .then((response) => { if (!response.body) throw Error('No response body'); @@ -339,6 +349,7 @@ export function handleFetchSharedAnswer( question: question, api_key: apiKey, }, + null, signal, ) .then((response) => { diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index f00eb546..7cd14d5e 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -42,6 +42,7 @@ export const fetchAnswer = createAsyncThunk< await handleFetchAnswerSteaming( question, signal, + state.preference.token, state.preference.selectedDocs!, state.conversation.queries, state.conversation.conversationId, @@ -53,7 +54,7 @@ export const fetchAnswer = createAsyncThunk< if (data.type === 'end') { dispatch(conversationSlice.actions.setStatus('idle')); - getConversations() + getConversations(state.preference.token) .then((fetchedConversations) => { dispatch(setConversations(fetchedConversations)); }) @@ -114,6 +115,7 @@ export const fetchAnswer = createAsyncThunk< const answer = await handleFetchAnswer( question, signal, + state.preference.token, state.preference.selectedDocs!, state.conversation.queries, state.conversation.conversationId, @@ -150,7 +152,7 @@ export const fetchAnswer = createAsyncThunk< }), ); dispatch(conversationSlice.actions.setStatus('idle')); - getConversations() + getConversations(state.preference.token) .then((fetchedConversations) => { dispatch(setConversations(fetchedConversations)); }) diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts index 7f4b9812..a2642dc5 100644 --- a/frontend/src/hooks/useDefaultDocument.ts +++ b/frontend/src/hooks/useDefaultDocument.ts @@ -1,20 +1,22 @@ import React from 'react'; import { useDispatch, useSelector } from 'react-redux'; -import { getDocs } from '../preferences/preferenceApi'; import { Doc } from '../models/misc'; +import { getDocs } from '../preferences/preferenceApi'; import { selectSelectedDocs, + selectToken, setSelectedDocs, setSourceDocs, } from '../preferences/preferenceSlice'; export default function useDefaultDocument() { const dispatch = useDispatch(); + const token = useSelector(selectToken); const selectedDoc = useSelector(selectSelectedDocs); const fetchDocs = () => { - getDocs().then((data) => { + getDocs(token).then((data) => { dispatch(setSourceDocs(data)); if (!selectedDoc) Array.isArray(data) && diff --git a/frontend/src/hooks/useTokenAuth.ts b/frontend/src/hooks/useTokenAuth.ts new file mode 100644 index 00000000..8f408600 --- /dev/null +++ b/frontend/src/hooks/useTokenAuth.ts @@ -0,0 +1,55 @@ +import { useEffect, useRef, useState } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; + +import userService from '../api/services/userService'; +import { selectToken, setToken } from '../preferences/preferenceSlice'; + +export default function useAuth() { + const dispatch = useDispatch(); + const token = useSelector(selectToken); + const [authType, setAuthType] = useState(null); + const [showTokenModal, setShowTokenModal] = useState(false); + const [isAuthLoading, setIsAuthLoading] = useState(true); + const isGeneratingToken = useRef(false); + + const generateNewToken = async () => { + if (isGeneratingToken.current) return; + isGeneratingToken.current = true; + const response = await userService.getNewToken(); + const { token: newToken } = await response.json(); + localStorage.setItem('authToken', newToken); + dispatch(setToken(newToken)); + setIsAuthLoading(false); + return newToken; + }; + + useEffect(() => { + const initializeAuth = async () => { + try { + const configRes = await userService.getConfig(); + const config = await configRes.json(); + setAuthType(config.auth_type); + + if (config.auth_type === 'session_jwt' && !token) { + await generateNewToken(); + } else if (config.auth_type === 'simple_jwt' && !token) { + setShowTokenModal(true); + setIsAuthLoading(false); + } else { + setIsAuthLoading(false); + } + } catch (error) { + console.error('Auth initialization failed:', error); + setIsAuthLoading(false); + } + }; + initializeAuth(); + }, []); + + const handleTokenSubmit = (enteredToken: string) => { + localStorage.setItem('authToken', enteredToken); + dispatch(setToken(enteredToken)); + setShowTokenModal(false); + }; + return { authType, showTokenModal, isAuthLoading, token, handleTokenSubmit }; +} diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx index 42b55d69..9885edab 100644 --- a/frontend/src/modals/AddToolModal.tsx +++ b/frontend/src/modals/AddToolModal.tsx @@ -1,12 +1,14 @@ import React, { useRef } from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; import userService from '../api/services/userService'; +import Spinner from '../components/Spinner'; import { useOutsideAlerter } from '../hooks'; import { ActiveState } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; import ConfigToolModal from './ConfigToolModal'; import { AvailableToolType } from './types'; -import Spinner from '../components/Spinner'; import WrapperComponent from './WrapperModal'; export default function AddToolModal({ @@ -23,6 +25,7 @@ export default function AddToolModal({ onToolAdded: (toolId: string) => void; }) { const { t } = useTranslation(); + const token = useSelector(selectToken); const modalRef = useRef(null); const [availableTools, setAvailableTools] = React.useState< AvailableToolType[] @@ -42,7 +45,7 @@ export default function AddToolModal({ const getAvailableTools = () => { setLoading(true); userService - .getAvailableTools() + .getAvailableTools(token) .then((res) => { return res.json(); }) @@ -55,14 +58,17 @@ export default function AddToolModal({ const handleAddTool = (tool: AvailableToolType) => { if (Object.keys(tool.configRequirements).length === 0) { userService - .createTool({ - name: tool.name, - displayName: tool.displayName, - description: tool.description, - config: {}, - actions: tool.actions, - status: true, - }) + .createTool( + { + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: {}, + actions: tool.actions, + status: true, + }, + token, + ) .then((res) => { if (res.status === 200) { return res.json(); diff --git a/frontend/src/modals/ConfigToolModal.tsx b/frontend/src/modals/ConfigToolModal.tsx index 05517c51..e631419a 100644 --- a/frontend/src/modals/ConfigToolModal.tsx +++ b/frontend/src/modals/ConfigToolModal.tsx @@ -1,11 +1,13 @@ import React from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; -import WrapperModal from './WrapperModal'; +import userService from '../api/services/userService'; import Input from '../components/Input'; import { ActiveState } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; import { AvailableToolType } from './types'; -import userService from '../api/services/userService'; +import WrapperModal from './WrapperModal'; interface ConfigToolModalProps { modalState: ActiveState; @@ -21,18 +23,22 @@ export default function ConfigToolModal({ getUserTools, }: ConfigToolModalProps) { const { t } = useTranslation(); + const token = useSelector(selectToken); const [authKey, setAuthKey] = React.useState(''); const handleAddTool = (tool: AvailableToolType) => { userService - .createTool({ - name: tool.name, - displayName: tool.displayName, - description: tool.description, - config: { token: authKey }, - actions: tool.actions, - status: true, - }) + .createTool( + { + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: true, + }, + token, + ) .then(() => { setModalState('INACTIVE'); getUserTools(); diff --git a/frontend/src/modals/CreateAPIKeyModal.tsx b/frontend/src/modals/CreateAPIKeyModal.tsx index a35efd60..79e2120d 100644 --- a/frontend/src/modals/CreateAPIKeyModal.tsx +++ b/frontend/src/modals/CreateAPIKeyModal.tsx @@ -6,7 +6,7 @@ import userService from '../api/services/userService'; import Dropdown from '../components/Dropdown'; import Input from '../components/Input'; import { CreateAPIKeyModalProps, Doc } from '../models/misc'; -import { selectSourceDocs } from '../preferences/preferenceSlice'; +import { selectSourceDocs, selectToken } from '../preferences/preferenceSlice'; import WrapperModal from './WrapperModal'; const embeddingsName = @@ -18,6 +18,7 @@ export default function CreateAPIKeyModal({ createAPIKey, }: CreateAPIKeyModalProps) { const { t } = useTranslation(); + const token = useSelector(selectToken); const docs = useSelector(selectSourceDocs); const [APIKeyName, setAPIKeyName] = React.useState(''); @@ -60,7 +61,7 @@ export default function CreateAPIKeyModal({ React.useEffect(() => { const handleFetchPrompts = async () => { try { - const response = await userService.getPrompts(); + const response = await userService.getPrompts(token); if (!response.ok) { throw new Error('Failed to fetch prompts'); } diff --git a/frontend/src/modals/JWTModal.tsx b/frontend/src/modals/JWTModal.tsx new file mode 100644 index 00000000..5f25b217 --- /dev/null +++ b/frontend/src/modals/JWTModal.tsx @@ -0,0 +1,47 @@ +import React, { useState } from 'react'; +import { useDispatch } from 'react-redux'; + +import Input from '../components/Input'; +import { ActiveState } from '../models/misc'; +import WrapperModal from './WrapperModal'; + +type JWTModalProps = { + modalState: ActiveState; + handleTokenSubmit: (enteredToken: string) => void; +}; + +export default function JWTModal({ + modalState, + handleTokenSubmit, +}: JWTModalProps) { + const [jwtToken, setJwtToken] = useState(''); + + if (modalState !== 'ACTIVE') return null; + + return ( + {}}> +
+ + Add JWT Token + +
+
+ setJwtToken(e.target.value)} + borderVariant="thin" + /> +
+ +
+ ); +} diff --git a/frontend/src/modals/ShareConversationModal.tsx b/frontend/src/modals/ShareConversationModal.tsx index 3fd444ad..73bb5acd 100644 --- a/frontend/src/modals/ShareConversationModal.tsx +++ b/frontend/src/modals/ShareConversationModal.tsx @@ -1,16 +1,21 @@ import { useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useSelector } from 'react-redux'; -import { - selectSourceDocs, - selectSelectedDocs, - selectChunks, - selectPrompt, -} from '../preferences/preferenceSlice'; + +import conversationService from '../api/services/conversationService'; +import Spinner from '../assets/spinner.svg'; import Dropdown from '../components/Dropdown'; import ToggleSwitch from '../components/ToggleSwitch'; import { Doc } from '../models/misc'; -import Spinner from '../assets/spinner.svg'; +import { + selectChunks, + selectPrompt, + selectSelectedDocs, + selectSourceDocs, + selectToken, +} from '../preferences/preferenceSlice'; +import WrapperModal from './WrapperModal'; + const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com'; const embeddingsName = import.meta.env.VITE_EMBEDDINGS_NAME || @@ -18,9 +23,6 @@ const embeddingsName = type StatusType = 'loading' | 'idle' | 'fetched' | 'failed'; -import conversationService from '../api/services/conversationService'; -import WrapperModal from './WrapperModal'; - export const ShareConversationModal = ({ close, conversationId, @@ -29,6 +31,7 @@ export const ShareConversationModal = ({ conversationId: string; }) => { const { t } = useTranslation(); + const token = useSelector(selectToken); const domain = window.location.origin; @@ -86,7 +89,7 @@ export const ShareConversationModal = ({ sourcePath && (payload.source = sourcePath.value); } conversationService - .shareConversation(isPromptable, payload) + .shareConversation(isPromptable, payload, token) .then((res) => { return res.json(); }) diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts index 8d21bdcd..d52580e0 100644 --- a/frontend/src/preferences/preferenceApi.ts +++ b/frontend/src/preferences/preferenceApi.ts @@ -3,9 +3,9 @@ import userService from '../api/services/userService'; import { Doc, GetDocsResponse } from '../models/misc'; //Fetches all JSON objects from the source. We only use the objects with the "model" property in SelectDocsModal.tsx. Hopefully can clean up the source file later. -export async function getDocs(): Promise { +export async function getDocs(token: string | null): Promise { try { - const response = await userService.getDocs(); + const response = await userService.getDocs(token); const data = await response.json(); const docs: Doc[] = []; @@ -26,10 +26,11 @@ export async function getDocsWithPagination( pageNumber = 1, rowsPerPage = 10, searchTerm = '', + token: string | null, ): Promise { try { const query = `sort=${sort}&order=${order}&page=${pageNumber}&rows=${rowsPerPage}&search=${searchTerm}`; - const response = await userService.getDocsWithPagination(query); + const response = await userService.getDocsWithPagination(query, token); const data = await response.json(); const docs: Doc[] = []; Array.isArray(data.paginated) && @@ -48,12 +49,12 @@ export async function getDocsWithPagination( } } -export async function getConversations(): Promise<{ +export async function getConversations(token: string | null): Promise<{ data: { name: string; id: string }[] | null; loading: boolean; }> { try { - const response = await conversationService.getConversations(); + const response = await conversationService.getConversations(token); const data = await response.json(); const conversations: { name: string; id: string }[] = []; @@ -100,8 +101,11 @@ export function setLocalRecentDocs(doc: Doc | null): void { docPath = 'local' + '/' + doc.name + '/'; } userService - .checkDocs({ - docs: docPath, - }) + .checkDocs( + { + docs: docPath, + }, + null, + ) .then((response) => response.json()); } diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts index 8b3064d5..4bca1a37 100644 --- a/frontend/src/preferences/preferenceSlice.ts +++ b/frontend/src/preferences/preferenceSlice.ts @@ -19,6 +19,7 @@ export interface Preference { data: { name: string; id: string }[] | null; loading: boolean; }; + token: string | null; modalState: ActiveState; paginatedDocuments: Doc[] | null; } @@ -42,6 +43,7 @@ const initialState: Preference = { data: null, loading: false, }, + token: localStorage.getItem('authToken') || null, modalState: 'INACTIVE', paginatedDocuments: null, }; @@ -65,6 +67,9 @@ export const prefSlice = createSlice({ setConversations: (state, action) => { state.conversations = action.payload; }, + setToken: (state, action) => { + state.token = action.payload; + }, setPrompt: (state, action) => { state.prompt = action.payload; }, @@ -85,6 +90,7 @@ export const { setSelectedDocs, setSourceDocs, setConversations, + setToken, setPrompt, setChunks, setTokenLimit, @@ -157,6 +163,7 @@ export const selectConversations = (state: RootState) => state.preference.conversations; export const selectConversationId = (state: RootState) => state.conversation.conversationId; +export const selectToken = (state: RootState) => state.preference.token; export const selectPrompt = (state: RootState) => state.preference.prompt; export const selectChunks = (state: RootState) => state.preference.chunks; export const selectTokenLimit = (state: RootState) => diff --git a/frontend/src/settings/APIKeys.tsx b/frontend/src/settings/APIKeys.tsx index b892787e..2da36c76 100644 --- a/frontend/src/settings/APIKeys.tsx +++ b/frontend/src/settings/APIKeys.tsx @@ -1,17 +1,20 @@ import React, { useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; import userService from '../api/services/userService'; import Trash from '../assets/trash.svg'; -import CreateAPIKeyModal from '../modals/CreateAPIKeyModal'; -import SaveAPIKeyModal from '../modals/SaveAPIKeyModal'; -import ConfirmationModal from '../modals/ConfirmationModal'; -import { APIKeyData } from './types'; import SkeletonLoader from '../components/SkeletonLoader'; import { useLoaderState } from '../hooks'; +import ConfirmationModal from '../modals/ConfirmationModal'; +import CreateAPIKeyModal from '../modals/CreateAPIKeyModal'; +import SaveAPIKeyModal from '../modals/SaveAPIKeyModal'; +import { selectToken } from '../preferences/preferenceSlice'; +import { APIKeyData } from './types'; export default function APIKeys() { const { t } = useTranslation(); + const token = useSelector(selectToken); const [isCreateModalOpen, setCreateModal] = useState(false); const [isSaveKeyModalOpen, setSaveKeyModal] = useState(false); const [newKey, setNewKey] = useState(''); @@ -25,7 +28,7 @@ export default function APIKeys() { const handleFetchKeys = async () => { setLoading(true); try { - const response = await userService.getAPIKeys(); + const response = await userService.getAPIKeys(token); if (!response.ok) { throw new Error('Failed to fetch API Keys'); } @@ -41,7 +44,7 @@ export default function APIKeys() { const handleDeleteKey = (id: string) => { setLoading(true); userService - .deleteAPIKey({ id }) + .deleteAPIKey({ id }, token) .then((response) => { if (!response.ok) { throw new Error('Failed to delete API Key'); @@ -71,7 +74,7 @@ export default function APIKeys() { }) => { setLoading(true); userService - .createAPIKey(payload) + .createAPIKey(payload, token) .then((response) => { if (!response.ok) { throw new Error('Failed to create API Key'); diff --git a/frontend/src/settings/Analytics.tsx b/frontend/src/settings/Analytics.tsx index 5ab95bac..a75d9aaf 100644 --- a/frontend/src/settings/Analytics.tsx +++ b/frontend/src/settings/Analytics.tsx @@ -1,5 +1,3 @@ -import React, { useState, useEffect } from 'react'; -import { useTranslation } from 'react-i18next'; import { BarElement, CategoryScale, @@ -9,18 +7,21 @@ import { Title, Tooltip, } from 'chart.js'; +import React, { useEffect, useState } from 'react'; import { Bar } from 'react-chartjs-2'; +import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; import userService from '../api/services/userService'; import Dropdown from '../components/Dropdown'; +import SkeletonLoader from '../components/SkeletonLoader'; +import { useLoaderState } from '../hooks'; +import { selectToken } from '../preferences/preferenceSlice'; import { htmlLegendPlugin } from '../utils/chartUtils'; import { formatDate } from '../utils/dateTimeUtils'; import { APIKeyData } from './types'; -import { useLoaderState } from '../hooks'; import type { ChartData } from 'chart.js'; -import SkeletonLoader from '../components/SkeletonLoader'; - ChartJS.register( CategoryScale, LinearScale, @@ -32,6 +33,7 @@ ChartJS.register( export default function Analytics() { const { t } = useTranslation(); + const token = useSelector(selectToken); const filterOptions = [ { label: t('settings.analytics.filterOptions.hour'), value: 'last_hour' }, @@ -97,7 +99,7 @@ export default function Analytics() { const fetchChatbots = async () => { setLoadingChatbots(true); try { - const response = await userService.getAPIKeys(); + const response = await userService.getAPIKeys(token); if (!response.ok) { throw new Error('Failed to fetch Chatbots'); } @@ -113,10 +115,13 @@ export default function Analytics() { const fetchMessagesData = async (chatbot_id?: string, filter?: string) => { setLoadingMessages(true); try { - const response = await userService.getMessageAnalytics({ - api_key_id: chatbot_id, - filter_option: filter, - }); + const response = await userService.getMessageAnalytics( + { + api_key_id: chatbot_id, + filter_option: filter, + }, + token, + ); if (!response.ok) { throw new Error('Failed to fetch analytics data'); } @@ -132,10 +137,13 @@ export default function Analytics() { const fetchTokenData = async (chatbot_id?: string, filter?: string) => { setLoadingTokens(true); try { - const response = await userService.getTokenAnalytics({ - api_key_id: chatbot_id, - filter_option: filter, - }); + const response = await userService.getTokenAnalytics( + { + api_key_id: chatbot_id, + filter_option: filter, + }, + token, + ); if (!response.ok) { throw new Error('Failed to fetch analytics data'); } @@ -151,10 +159,13 @@ export default function Analytics() { const fetchFeedbackData = async (chatbot_id?: string, filter?: string) => { setLoadingFeedback(true); try { - const response = await userService.getFeedbackAnalytics({ - api_key_id: chatbot_id, - filter_option: filter, - }); + const response = await userService.getFeedbackAnalytics( + { + api_key_id: chatbot_id, + filter_option: filter, + }, + token, + ); if (!response.ok) { throw new Error('Failed to fetch analytics data'); } diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx index 2b29ef08..0fc03f24 100644 --- a/frontend/src/settings/Documents.tsx +++ b/frontend/src/settings/Documents.tsx @@ -1,6 +1,6 @@ import React, { useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useDispatch } from 'react-redux'; +import { useDispatch, useSelector } from 'react-redux'; import userService from '../api/services/userService'; import ArrowLeft from '../assets/arrow-left.svg'; @@ -22,6 +22,7 @@ import ConfirmationModal from '../modals/ConfirmationModal'; import { ActiveState, Doc, DocumentsProps } from '../models/misc'; import { getDocs, getDocsWithPagination } from '../preferences/preferenceApi'; import { + selectToken, setPaginatedDocuments, setSourceDocs, } from '../preferences/preferenceSlice'; @@ -53,6 +54,7 @@ export default function Documents({ }: DocumentsProps) { const { t } = useTranslation(); const dispatch = useDispatch(); + const token = useSelector(selectToken); const [searchTerm, setSearchTerm] = useState(''); const [modalState, setModalState] = useState('INACTIVE'); @@ -163,6 +165,7 @@ export default function Documents({ page, rowsPerPg, searchTerm, + token, ) .then((data) => { dispatch(setPaginatedDocuments(data ? data.docs : [])); @@ -179,9 +182,9 @@ export default function Documents({ const handleManageSync = (doc: Doc, sync_frequency: string) => { setLoading(true); userService - .manageSync({ source_id: doc.id, sync_frequency }) + .manageSync({ source_id: doc.id, sync_frequency }, token) .then(() => { - return getDocs(); + return getDocs(token); }) .then((data) => { dispatch(setSourceDocs(data)); @@ -190,6 +193,8 @@ export default function Documents({ sortOrder, currentPage, rowsPerPage, + searchTerm, + token, ); }) .then((paginatedData) => { @@ -519,6 +524,7 @@ function DocumentChunks({ handleGoBack: () => void; }) { const { t } = useTranslation(); + const token = useSelector(selectToken); const [isDarkTheme] = useDarkTheme(); const [paginatedChunks, setPaginatedChunks] = useState([]); const [page, setPage] = useState(1); @@ -536,7 +542,7 @@ function DocumentChunks({ setLoading(true); try { userService - .getDocumentChunks(document.id ?? '', page, perPage) + .getDocumentChunks(document.id ?? '', page, perPage, token) .then((response) => { if (!response.ok) { setLoading(false); @@ -561,13 +567,16 @@ function DocumentChunks({ const handleAddChunk = (title: string, text: string) => { try { userService - .addChunk({ - id: document.id ?? '', - text: text, - metadata: { - title: title, + .addChunk( + { + id: document.id ?? '', + text: text, + metadata: { + title: title, + }, }, - }) + token, + ) .then((response) => { if (!response.ok) { throw new Error('Failed to add chunk'); @@ -582,14 +591,17 @@ function DocumentChunks({ const handleUpdateChunk = (title: string, text: string, chunk: ChunkType) => { try { userService - .updateChunk({ - id: document.id ?? '', - chunk_id: chunk.doc_id, - text: text, - metadata: { - title: title, + .updateChunk( + { + id: document.id ?? '', + chunk_id: chunk.doc_id, + text: text, + metadata: { + title: title, + }, }, - }) + token, + ) .then((response) => { if (!response.ok) { throw new Error('Failed to update chunk'); @@ -604,7 +616,7 @@ function DocumentChunks({ const handleDeleteChunk = (chunk: ChunkType) => { try { userService - .deleteChunk(document.id ?? '', chunk.doc_id) + .deleteChunk(document.id ?? '', chunk.doc_id, token) .then((response) => { if (!response.ok) { throw new Error('Failed to delete chunk'); diff --git a/frontend/src/settings/General.tsx b/frontend/src/settings/General.tsx index 210f6bbc..fa64507e 100644 --- a/frontend/src/settings/General.tsx +++ b/frontend/src/settings/General.tsx @@ -8,6 +8,7 @@ import { useDarkTheme } from '../hooks'; import { selectChunks, selectPrompt, + selectToken, selectTokenLimit, setChunks, setModalStateDeleteConv, @@ -21,6 +22,7 @@ export default function General() { t, i18n: { changeLanguage }, } = useTranslation(); + const token = useSelector(selectToken); const themes = [ { value: 'Light', label: t('settings.general.light') }, { value: 'Dark', label: t('settings.general.dark') }, @@ -64,7 +66,7 @@ export default function General() { React.useEffect(() => { const handleFetchPrompts = async () => { try { - const response = await userService.getPrompts(); + const response = await userService.getPrompts(token); if (!response.ok) { throw new Error('Failed to fetch prompts'); } diff --git a/frontend/src/settings/Logs.tsx b/frontend/src/settings/Logs.tsx index 24cf3a6d..2507c106 100644 --- a/frontend/src/settings/Logs.tsx +++ b/frontend/src/settings/Logs.tsx @@ -1,5 +1,6 @@ -import React, { useState, useEffect, useRef, useCallback } from 'react'; +import React, { useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; import userService from '../api/services/userService'; import ChevronRight from '../assets/chevron-right.svg'; @@ -7,10 +8,12 @@ import CopyButton from '../components/CopyButton'; import Dropdown from '../components/Dropdown'; import SkeletonLoader from '../components/SkeletonLoader'; import { useLoaderState } from '../hooks'; +import { selectToken } from '../preferences/preferenceSlice'; import { APIKeyData, LogData } from './types'; export default function Logs() { const { t } = useTranslation(); + const token = useSelector(selectToken); const [chatbots, setChatbots] = useState([]); const [selectedChatbot, setSelectedChatbot] = useState(); const [logs, setLogs] = useState([]); @@ -22,7 +25,7 @@ export default function Logs() { const fetchChatbots = async () => { setLoadingChatbots(true); try { - const response = await userService.getAPIKeys(); + const response = await userService.getAPIKeys(token); if (!response.ok) { throw new Error('Failed to fetch Chatbots'); } @@ -38,11 +41,14 @@ export default function Logs() { const fetchLogs = async () => { setLoadingLogs(true); try { - const response = await userService.getLogs({ - page: page, - api_key_id: selectedChatbot?.id, - page_size: 10, - }); + const response = await userService.getLogs( + { + page: page, + api_key_id: selectedChatbot?.id, + page_size: 10, + }, + token, + ); if (!response.ok) { throw new Error('Failed to fetch logs'); } diff --git a/frontend/src/settings/Prompts.tsx b/frontend/src/settings/Prompts.tsx index 654b610a..33540296 100644 --- a/frontend/src/settings/Prompts.tsx +++ b/frontend/src/settings/Prompts.tsx @@ -1,9 +1,11 @@ import React from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; import userService from '../api/services/userService'; import Dropdown from '../components/Dropdown'; import { ActiveState, PromptProps } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; import PromptsModal from '../preferences/PromptsModal'; export default function Prompts({ @@ -24,6 +26,7 @@ export default function Prompts({ setEditPromptName(name); onSelectPrompt(name, id, type); }; + const token = useSelector(selectToken); const [newPromptName, setNewPromptName] = React.useState(''); const [newPromptContent, setNewPromptContent] = React.useState(''); const [editPromptName, setEditPromptName] = React.useState(''); @@ -39,10 +42,13 @@ export default function Prompts({ const handleAddPrompt = async () => { try { - const response = await userService.createPrompt({ - name: newPromptName, - content: newPromptContent, - }); + const response = await userService.createPrompt( + { + name: newPromptName, + content: newPromptContent, + }, + token, + ); if (!response.ok) { throw new Error('Failed to add prompt'); } @@ -65,7 +71,7 @@ export default function Prompts({ const handleDeletePrompt = (id: string) => { setPrompts(prompts.filter((prompt) => prompt.id !== id)); userService - .deletePrompt({ id }) + .deletePrompt({ id }, token) .then((response) => { if (!response.ok) { throw new Error('Failed to delete prompt'); @@ -81,7 +87,7 @@ export default function Prompts({ const handleFetchPromptContent = async (id: string) => { try { - const response = await userService.getSinglePrompt(id); + const response = await userService.getSinglePrompt(id, token); if (!response.ok) { throw new Error('Failed to fetch prompt content'); } @@ -94,11 +100,14 @@ export default function Prompts({ const handleSaveChanges = (id: string, type: string) => { userService - .updatePrompt({ - id: id, - name: editPromptName, - content: editPromptContent, - }) + .updatePrompt( + { + id: id, + name: editPromptName, + content: editPromptContent, + }, + token, + ) .then((response) => { if (!response.ok) { throw new Error('Failed to update prompt'); diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx index af57db21..d75a3852 100644 --- a/frontend/src/settings/ToolConfig.tsx +++ b/frontend/src/settings/ToolConfig.tsx @@ -1,4 +1,6 @@ import React from 'react'; +import { useSelector } from 'react-redux'; + import userService from '../api/services/userService'; import ArrowLeft from '../assets/arrow-left.svg'; import CircleCheck from '../assets/circle-check.svg'; @@ -9,6 +11,7 @@ import Input from '../components/Input'; import ToggleSwitch from '../components/ToggleSwitch'; import AddActionModal from '../modals/AddActionModal'; import { ActiveState } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; import { APIActionType, APIToolType, UserToolType } from './types'; import { useTranslation } from 'react-i18next'; @@ -21,6 +24,7 @@ export default function ToolConfig({ setTool: (tool: UserToolType | APIToolType) => void; handleGoBack: () => void; }) { + const token = useSelector(selectToken); const [authKey, setAuthKey] = React.useState( 'token' in tool.config ? tool.config.token : '', ); @@ -57,22 +61,25 @@ export default function ToolConfig({ const handleSaveChanges = () => { userService - .updateTool({ - id: tool.id, - name: tool.name, - displayName: tool.displayName, - description: tool.description, - config: tool.name === 'api_tool' ? tool.config : { token: authKey }, - actions: 'actions' in tool ? tool.actions : [], - status: tool.status, - }) + .updateTool( + { + id: tool.id, + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: tool.name === 'api_tool' ? tool.config : { token: authKey }, + actions: 'actions' in tool ? tool.actions : [], + status: tool.status, + }, + token, + ) .then(() => { handleGoBack(); }); }; const handleDelete = () => { - userService.deleteTool({ id: tool.id }).then(() => { + userService.deleteTool({ id: tool.id }, token).then(() => { handleGoBack(); }); }; diff --git a/frontend/src/settings/Tools.tsx b/frontend/src/settings/Tools.tsx index 7432ecf0..b42195ed 100644 --- a/frontend/src/settings/Tools.tsx +++ b/frontend/src/settings/Tools.tsx @@ -1,18 +1,22 @@ import React from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelector } from 'react-redux'; import userService from '../api/services/userService'; import CogwheelIcon from '../assets/cogwheel.svg'; import Input from '../components/Input'; import Spinner from '../components/Spinner'; +import ToggleSwitch from '../components/ToggleSwitch'; import AddToolModal from '../modals/AddToolModal'; import { ActiveState } from '../models/misc'; +import { selectToken } from '../preferences/preferenceSlice'; import ToolConfig from './ToolConfig'; import { APIToolType, UserToolType } from './types'; -import ToggleSwitch from '../components/ToggleSwitch'; export default function Tools() { const { t } = useTranslation(); + const token = useSelector(selectToken); + const [searchTerm, setSearchTerm] = React.useState(''); const [addToolModalState, setAddToolModalState] = React.useState('INACTIVE'); @@ -25,7 +29,7 @@ export default function Tools() { const getUserTools = () => { setLoading(true); userService - .getUserTools() + .getUserTools(token) .then((res) => { return res.json(); }) @@ -41,7 +45,7 @@ export default function Tools() { const updateToolStatus = (toolId: string, newStatus: boolean) => { userService - .updateToolStatus({ id: toolId, status: newStatus }) + .updateToolStatus({ id: toolId, status: newStatus }, token) .then(() => { setUserTools((prevTools) => prevTools.map((tool) => @@ -65,7 +69,7 @@ export default function Tools() { const handleToolAdded = (toolId: string) => { userService - .getUserTools() + .getUserTools(token) .then((res) => res.json()) .then((data) => { const newTool = data.tools.find( diff --git a/frontend/src/settings/index.tsx b/frontend/src/settings/index.tsx index 918e4d15..cd504858 100644 --- a/frontend/src/settings/index.tsx +++ b/frontend/src/settings/index.tsx @@ -11,6 +11,7 @@ import { selectSourceDocs, setPaginatedDocuments, setSourceDocs, + selectToken, } from '../preferences/preferenceSlice'; import Analytics from './Analytics'; import APIKeys from './APIKeys'; @@ -28,6 +29,7 @@ export default function Settings() { null, ); + const token = useSelector(selectToken); const documents = useSelector(selectSourceDocs); const paginatedDocuments = useSelector(selectPaginatedDocuments); const updateWidgetScreenshot = (screenshot: File | null) => { @@ -41,7 +43,7 @@ export default function Settings() { const handleDeleteClick = (index: number, doc: Doc) => { userService - .deletePath(doc.id ?? '') + .deletePath(doc.id ?? '', token) .then((response) => { if (response.ok && documents) { if (paginatedDocuments) { diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 8f426ed6..02aa9a68 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -16,6 +16,7 @@ const doc = localStorage.getItem('DocsGPTRecentDocs'); const preloadedState: { preference: Preference } = { preference: { apiKey: key ?? '', + token: localStorage.getItem('authToken') ?? null, prompt: prompt !== null ? JSON.parse(prompt) diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx index df06af2f..e70c930f 100644 --- a/frontend/src/upload/Upload.tsx +++ b/frontend/src/upload/Upload.tsx @@ -9,21 +9,22 @@ import WebsiteCollect from '../assets/website_collect.svg'; import Dropdown from '../components/Dropdown'; import Input from '../components/Input'; import ToggleSwitch from '../components/ToggleSwitch'; +import WrapperModal from '../modals/WrapperModal'; import { ActiveState, Doc } from '../models/misc'; import { getDocs } from '../preferences/preferenceApi'; import { + selectSourceDocs, + selectToken, setSelectedDocs, setSourceDocs, - selectSourceDocs, } from '../preferences/preferenceSlice'; -import WrapperModal from '../modals/WrapperModal'; +import { IngestorDefaultConfigs } from '../upload/types/ingestor'; import { - IngestorType, + FormField, IngestorConfig, IngestorFormSchemas, - FormField, + IngestorType, } from './types/ingestor'; -import { IngestorDefaultConfigs } from '../upload/types/ingestor'; function Upload({ receivedFile = [], @@ -40,6 +41,7 @@ function Upload({ close: () => void; onSuccessfulUpload?: () => void; }) { + const token = useSelector(selectToken); const [docName, setDocName] = useState(receivedFile[0]?.name); const [remoteName, setRemoteName] = useState(''); const [files, setfiles] = useState(receivedFile); @@ -297,12 +299,12 @@ function Upload({ if ((progress?.percentage ?? 0) < 100) { timeoutID = setTimeout(() => { userService - .getTaskStatus(progress?.taskId as string) + .getTaskStatus(progress?.taskId as string, null) .then((data) => data.json()) .then((data) => { if (data.status == 'SUCCESS') { if (data.result.limited === true) { - getDocs().then((data) => { + getDocs(token).then((data) => { dispatch(setSourceDocs(data)); dispatch( setSelectedDocs( @@ -322,7 +324,7 @@ function Upload({ }, ); } else { - getDocs().then((data) => { + getDocs(token).then((data) => { dispatch(setSourceDocs(data)); const docIds = new Set( (Array.isArray(sourceDocs) && @@ -413,6 +415,7 @@ function Upload({ }, 3000); }; xhr.open('POST', `${apiHost + '/api/upload'}`); + xhr.setRequestHeader('Authorization', `Bearer ${token}`); xhr.send(formData); };