diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 5a221e8d..ebe1e971 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -122,6 +122,7 @@ def save_conversation( source_log_docs, tool_calls, llm, + decoded_token, index=None, api_key=None, ): @@ -180,7 +181,7 @@ def save_conversation( completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_data = { - "user": "local", + "user": decoded_token.get("sub"), "date": datetime.datetime.utcnow(), "name": completion, "queries": [ @@ -221,6 +222,7 @@ def complete_stream( retriever, conversation_id, user_api_key, + decoded_token, isNoneDoc=False, index=None, should_save_conversation=True, @@ -271,6 +273,7 @@ def complete_stream( source_log_docs, tool_calls, llm, + decoded_token, index, api_key=user_api_key, ) @@ -286,7 +289,7 @@ def complete_stream( { "action": "stream_answer", "level": "info", - "user": "local", + "user": decoded_token.get("sub"), "api_key": user_api_key, "question": question, "response": response_full, @@ -381,15 +384,21 @@ class Stream(Resource): source = {"active_docs": data_key.get("source")} retriever_name = data_key.get("retriever", retriever_name) user_api_key = data["api_key"] + decoded_token = {"sub": data_key.get("user")} elif "active_docs" in data: source = {"active_docs": data["active_docs"]} retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None + decoded_token = request.decoded_token else: source = {} user_api_key = None + decoded_token = request.decoded_token + + if not decoded_token: + return bad_request(401, "Unauthorized") logger.info( f"/stream - request_data: {data}, source: {source}", @@ -429,6 +438,7 @@ class Stream(Resource): retriever=retriever, conversation_id=conversation_id, user_api_key=user_api_key, + decoded_token=decoded_token, isNoneDoc=data.get("isNoneDoc"), index=index, should_save_conversation=save_conv, @@ -521,13 +531,21 @@ class Answer(Resource): source = {"active_docs": data_key.get("source")} retriever_name = data_key.get("retriever", retriever_name) user_api_key = data["api_key"] + decoded_token = {"sub": data_key.get("user")} + elif "active_docs" in data: source = {"active_docs": data["active_docs"]} retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None + decoded_token = request.decoded_token + else: source = {} user_api_key = None + decoded_token = request.decoded_token + + if not decoded_token: + return bad_request(401, "Unauthorized") prompt = get_prompt(prompt_id) @@ -614,6 +632,7 @@ class Answer(Resource): source_log_docs, tool_calls, llm, + decoded_token, api_key=user_api_key, ) ) @@ -623,7 +642,7 @@ class Answer(Resource): { "action": "api_answer", "level": "info", - "user": "local", + "user": decoded_token.get("sub"), "api_key": user_api_key, "question": question, "response": response_full, @@ -692,12 +711,17 @@ class Search(Resource): chunks = int(data_key.get("chunks", 2)) source = {"active_docs": data_key.get("source")} user_api_key = data["api_key"] + decoded_token = {"sub": data_key.get("user")} + elif "active_docs" in data: source = {"active_docs": data["active_docs"]} user_api_key = None + decoded_token = request.decoded_token + else: source = {} user_api_key = None + decoded_token = request.decoded_token logger.info( f"/api/answer - request_data: {data}, source: {source}", @@ -723,7 +747,7 @@ class Search(Resource): { "action": "api_search", "level": "info", - "user": "local", + "user": decoded_token.get("sub"), "api_key": user_api_key, "question": question, "sources": docs, diff --git a/application/api/user/routes.py b/application/api/user/routes.py index d7fb4d89..9d1bfadc 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -15,7 +15,6 @@ from werkzeug.utils import secure_filename from application.agents.tools.tool_manager import ToolManager from application.api.user.tasks import ingest, ingest_remote - from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api @@ -110,11 +109,18 @@ class GetConversations(Resource): description="Retrieve a list of the latest 30 conversations (excluding API key conversations)", ) def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) try: - conversations = conversations_collection.find( - {"api_key": {"$exists": False}} - ).sort("date", -1).limit(30) - + conversations = ( + conversations_collection.find( + {"api_key": {"$exists": False}, "user": decoded_token.get("sub")} + ) + .sort("date", -1) + .limit(30) + ) + list_conversations = [ {"id": str(conversation["_id"]), "name": conversation["name"]} for conversation in conversations @@ -132,6 +138,9 @@ class GetSingleConversation(Resource): params={"id": "The conversation ID"}, ) def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) conversation_id = request.args.get("id") if not conversation_id: return make_response( @@ -140,7 +149,7 @@ class GetSingleConversation(Resource): try: conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} + {"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")} ) if not conversation: return make_response(jsonify({"status": "not found"}), 404) @@ -227,7 +236,7 @@ class SubmitFeedback(Resource): { "$unset": { f"queries.{data['question_index']}.feedback": "", - f"queries.{data['question_index']}.feedback_timestamp": "" + f"queries.{data['question_index']}.feedback_timestamp": "", } }, ) @@ -240,8 +249,12 @@ class SubmitFeedback(Resource): }, { "$set": { - f"queries.{data['question_index']}.feedback": data["feedback"], - f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(datetime.timezone.utc) + f"queries.{data['question_index']}.feedback": data[ + "feedback" + ], + f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now( + datetime.timezone.utc + ), } }, ) @@ -1211,7 +1224,13 @@ class GetMessageAnalytics(Resource): required=False, description="Filter option for analytics", default="last_30_days", - enum=["last_hour", "last_24_hour", "last_7_days", "last_15_days", "last_30_days"], + enum=[ + "last_hour", + "last_24_hour", + "last_7_days", + "last_15_days", + "last_30_days", + ], ), }, ) @@ -1244,9 +1263,9 @@ class GetMessageAnalytics(Resource): else: if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: filter_days = ( - 6 if filter_option == "last_7_days" - else 14 if filter_option == "last_15_days" - else 29 + 6 + if filter_option == "last_7_days" + else 14 if filter_option == "last_15_days" else 29 ) else: return make_response( @@ -1254,25 +1273,20 @@ class GetMessageAnalytics(Resource): ) start_date = end_date - datetime.timedelta(days=filter_days) start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) - end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) + end_date = end_date.replace( + hour=23, minute=59, second=59, microsecond=999999 + ) group_format = "%Y-%m-%d" try: pipeline = [ # Initial match for API key if provided - { - "$match": { - "api_key": api_key if api_key else {"$exists": False} - } - }, + {"$match": {"api_key": api_key if api_key else {"$exists": False}}}, {"$unwind": "$queries"}, # Match queries within the time range { "$match": { - "queries.timestamp": { - "$gte": start_date, - "$lte": end_date - } + "queries.timestamp": {"$gte": start_date, "$lte": end_date} } }, # Group by formatted timestamp @@ -1281,14 +1295,14 @@ class GetMessageAnalytics(Resource): "_id": { "$dateToString": { "format": group_format, - "date": "$queries.timestamp" + "date": "$queries.timestamp", } }, - "count": {"$sum": 1} + "count": {"$sum": 1}, } }, # Sort by timestamp - {"$sort": {"_id": 1}} + {"$sort": {"_id": 1}}, ] message_data = conversations_collection.aggregate(pipeline) @@ -1511,11 +1525,21 @@ class GetFeedbackAnalytics(Resource): if filter_option == "last_hour": start_date = end_date - datetime.timedelta(hours=1) group_format = "%Y-%m-%d %H:%M:00" - date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}} + date_field = { + "$dateToString": { + "format": group_format, + "date": "$queries.feedback_timestamp", + } + } elif filter_option == "last_24_hour": start_date = end_date - datetime.timedelta(hours=24) group_format = "%Y-%m-%d %H:00" - date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}} + date_field = { + "$dateToString": { + "format": group_format, + "date": "$queries.feedback_timestamp", + } + } else: if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: filter_days = ( @@ -1533,13 +1557,21 @@ class GetFeedbackAnalytics(Resource): hour=23, minute=59, second=59, microsecond=999999 ) group_format = "%Y-%m-%d" - date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}} + date_field = { + "$dateToString": { + "format": group_format, + "date": "$queries.feedback_timestamp", + } + } try: match_stage = { "$match": { - "queries.feedback_timestamp": {"$gte": start_date, "$lte": end_date}, - "queries.feedback": {"$exists": True} + "queries.feedback_timestamp": { + "$gte": start_date, + "$lte": end_date, + }, + "queries.feedback": {"$exists": True}, } } if api_key: diff --git a/application/app.py b/application/app.py index 4eb40331..ddac00b0 100644 --- a/application/app.py +++ b/application/app.py @@ -1,20 +1,26 @@ import platform import dotenv -from flask import Flask, redirect, request +from flask import Flask, jsonify, redirect, request +from jose import jwt + +from application.auth import get_or_create_user_id, handle_auth + from application.core.logging_config import setup_logging + setup_logging() -from application.api.answer.routes import answer # noqa: E402 -from application.api.internal.routes import internal # noqa: E402 -from application.api.user.routes import user # noqa: E402 -from application.celery_init import celery # noqa: E402 -from application.core.settings import settings # noqa: E402 -from application.extensions import api # noqa: E402 +from application.api.answer.routes import answer # noqa: E402 +from application.api.internal.routes import internal # noqa: E402 +from application.api.user.routes import user # noqa: E402 +from application.celery_init import celery # noqa: E402 +from application.core.settings import settings # noqa: E402 +from application.extensions import api # noqa: E402 if platform.system() == "Windows": import pathlib + pathlib.PosixPath = pathlib.WindowsPath dotenv.load_dotenv() @@ -32,6 +38,13 @@ app.config.update( celery.config_from_object("application.celeryconfig") api.init_app(app) +SIMPLE_JWT_TOKEN = None +if settings.AUTH_TYPE == "simple_jwt": + user_id = get_or_create_user_id() + payload = {"sub": user_id} + SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256") + print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}") + @app.route("/") def home(): @@ -41,11 +54,27 @@ def home(): return "Welcome to DocsGPT Backend!" +@app.before_request +def authenticate_request(): + if request.method == "OPTIONS": + return "", 200 + + decoded_token = handle_auth(request) + if "message" in decoded_token: + request.decoded_token = None + elif "error" in decoded_token: + return jsonify(decoded_token), 401 + else: + request.decoded_token = decoded_token + + @app.after_request def after_request(response): response.headers.add("Access-Control-Allow-Origin", "*") - response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization") - response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") + response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization") + response.headers.add( + "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS" + ) return response diff --git a/application/auth.py b/application/auth.py new file mode 100644 index 00000000..be2501b0 --- /dev/null +++ b/application/auth.py @@ -0,0 +1,39 @@ +import uuid + +from jose import jwt + +from application.core.settings import settings + + +def handle_auth(request, data={}): + if settings.AUTH_TYPE == "simple_jwt": + jwt_token = request.headers.get("Authorization") + if not jwt_token: + return {"message": "Missing Authorization header"} + + jwt_token = jwt_token.replace("Bearer ", "") + + try: + decoded_token = jwt.decode( + jwt_token, + settings.JWT_SECRET_KEY, + algorithms=["HS256"], + options={"verify_exp": False}, + ) + return decoded_token + except Exception as e: + return {"message": f"Authentication error: {str(e)}"} + else: + return {"sub": "local"} + + +def get_or_create_user_id(): + try: + with open(settings.USER_ID_FILE, "r") as f: + user_id = f.read().strip() + return user_id + except FileNotFoundError: + user_id = str(uuid.uuid4()) + with open(settings.USER_ID_FILE, "w") as f: + f.write(user_id) + return user_id diff --git a/application/core/settings.py b/application/core/settings.py index 04d7bbea..a757a7f3 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -10,6 +10,7 @@ current_dir = os.path.dirname( class Settings(BaseSettings): + AUTH_TYPE: Optional[str] = "simple_jwt" LLM_NAME: str = "docsgpt" MODEL_NAME: Optional[str] = ( None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo @@ -98,6 +99,9 @@ class Settings(BaseSettings): FLASK_DEBUG_MODE: bool = False + JWT_SECRET_KEY: str = "" + USER_ID_FILE: str = os.path.join(current_dir, "user_id.txt") + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/requirements.txt b/application/requirements.txt index c912d373..c00c24f7 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -69,6 +69,7 @@ pymongo==4.10.1 pypdf==5.2.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 +python-jose==3.4.0 python-pptx==1.0.2 qdrant-client==1.13.2 redis==5.2.1 diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index bba83037..ad30228b 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -2,28 +2,34 @@ import { useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; import { NavLink, useNavigate } from 'react-router-dom'; + import conversationService from './api/services/conversationService'; import userService from './api/services/userService'; import Add from './assets/add.svg'; -import openNewChat from './assets/openNewChat.svg'; -import Hamburger from './assets/hamburger.svg'; import DocsGPT3 from './assets/cute_docsgpt3.svg'; import Discord from './assets/discord.svg'; import Expand from './assets/expand.svg'; import Github from './assets/github.svg'; +import Hamburger from './assets/hamburger.svg'; +import openNewChat from './assets/openNewChat.svg'; import SettingGear from './assets/settingGear.svg'; +import SpinnerDark from './assets/spinner-dark.svg'; +import Spinner from './assets/spinner.svg'; import Twitter from './assets/TwitterX.svg'; import UploadIcon from './assets/upload.svg'; +import Help from './components/Help'; import SourceDropdown from './components/SourceDropdown'; import { + handleAbort, + selectQueries, setConversation, updateConversationId, - handleAbort, } from './conversation/conversationSlice'; import ConversationTile from './conversation/ConversationTile'; import { useDarkTheme, useMediaQuery } from './hooks'; import useDefaultDocument from './hooks/useDefaultDocument'; import DeleteConvModal from './modals/DeleteConvModal'; +import JWTModal from './modals/JWTModal'; import { ActiveState, Doc } from './models/misc'; import { getConversations, getDocs } from './preferences/preferenceApi'; import { @@ -31,20 +37,17 @@ import { selectConversationId, selectConversations, selectModalStateDeleteConv, + selectPaginatedDocuments, selectSelectedDocs, selectSourceDocs, - selectPaginatedDocuments, + selectToken, setConversations, setModalStateDeleteConv, + setPaginatedDocuments, setSelectedDocs, setSourceDocs, - setPaginatedDocuments, } from './preferences/preferenceSlice'; -import Spinner from './assets/spinner.svg'; -import SpinnerDark from './assets/spinner-dark.svg'; -import { selectQueries } from './conversation/conversationSlice'; import Upload from './upload/Upload'; -import Help from './components/Help'; interface NavigationProps { navOpen: boolean; @@ -53,6 +56,7 @@ interface NavigationProps { export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const dispatch = useDispatch(); + const token = useSelector(selectToken); const queries = useSelector(selectQueries); const docs = useSelector(selectSourceDocs); const selectedDocs = useSelector(selectSelectedDocs); @@ -70,6 +74,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const [uploadModalState, setUploadModalState] = useState('INACTIVE'); + const [authKeyModalState, setAuthKeyModalState] = + useState('INACTIVE'); const navRef = useRef(null); @@ -86,7 +92,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { async function fetchConversations() { dispatch(setConversations({ ...conversations, loading: true })); - return await getConversations() + return await getConversations(token) .then((fetchedConversations) => { dispatch(setConversations(fetchedConversations)); }) @@ -99,7 +105,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleDeleteAllConversations = () => { setIsDeletingConversation(true); conversationService - .deleteAll() + .deleteAll(token) .then(() => { fetchConversations(); }) @@ -109,7 +115,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleDeleteConversation = (id: string) => { setIsDeletingConversation(true); conversationService - .delete(id, {}) + .delete(id, {}, token) .then(() => { fetchConversations(); resetConversation(); @@ -119,9 +125,9 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleDeleteClick = (doc: Doc) => { userService - .deletePath(doc.id ?? '') + .deletePath(doc.id ?? '', token) .then(() => { - return getDocs(); + return getDocs(token); }) .then((updatedDocs) => { dispatch(setSourceDocs(updatedDocs)); @@ -145,7 +151,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleConversationClick = (index: string) => { conversationService - .getConversation(index) + .getConversation(index, token) .then((response) => response.json()) .then((data) => { navigate('/'); @@ -177,7 +183,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { id: string; }) { await conversationService - .update(updatedConversation) + .update(updatedConversation, token) .then((response) => response.json()) .then((data) => { if (data) { @@ -197,6 +203,14 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { useEffect(() => { setNavOpen(!isMobile); }, [isMobile]); + + useEffect(() => { + const authToken = localStorage.getItem('authToken'); + if (!authToken) { + setAuthKeyModalState('ACTIVE'); + } + }, []); + useDefaultDocument(); return ( @@ -470,6 +484,10 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { close={() => setUploadModalState('INACTIVE')} > )} + ); } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 21699721..3db613fc 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -4,14 +4,24 @@ const defaultHeaders = { 'Content-Type': 'application/json', }; +const getHeaders = (token: string | null, customHeaders = {}): HeadersInit => { + return { + ...defaultHeaders, + ...(token ? { Authorization: `Bearer ${token}` } : {}), + ...customHeaders, + }; +}; + const apiClient = { - get: (url: string, headers = {}, signal?: AbortSignal): Promise => + get: ( + url: string, + token: string | null, + headers = {}, + signal?: AbortSignal, + ): Promise => fetch(`${baseURL}${url}`, { method: 'GET', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), signal, }).then((response) => { return response; @@ -20,15 +30,13 @@ const apiClient = { post: ( url: string, data: any, + token: string | null, headers = {}, signal?: AbortSignal, ): Promise => fetch(`${baseURL}${url}`, { method: 'POST', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), body: JSON.stringify(data), signal, }).then((response) => { @@ -38,28 +46,28 @@ const apiClient = { put: ( url: string, data: any, + token: string | null, headers = {}, signal?: AbortSignal, ): Promise => fetch(`${baseURL}${url}`, { method: 'PUT', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), body: JSON.stringify(data), signal, }).then((response) => { return response; }), - delete: (url: string, headers = {}, signal?: AbortSignal): Promise => + delete: ( + url: string, + token: string | null, + headers = {}, + signal?: AbortSignal, + ): Promise => fetch(`${baseURL}${url}`, { method: 'DELETE', - headers: { - ...defaultHeaders, - ...headers, - }, + headers: getHeaders(token, headers), signal, }).then((response) => { return response; diff --git a/frontend/src/api/services/conversationService.ts b/frontend/src/api/services/conversationService.ts index aaf703de..853a6863 100644 --- a/frontend/src/api/services/conversationService.ts +++ b/frontend/src/api/services/conversationService.ts @@ -2,31 +2,58 @@ import apiClient from '../client'; import endpoints from '../endpoints'; const conversationService = { - answer: (data: any, signal: AbortSignal): Promise => - apiClient.post(endpoints.CONVERSATION.ANSWER, data, {}, signal), - answerStream: (data: any, signal: AbortSignal): Promise => - apiClient.post(endpoints.CONVERSATION.ANSWER_STREAMING, data, {}, signal), - search: (data: any): Promise => - apiClient.post(endpoints.CONVERSATION.SEARCH, data), - feedback: (data: any): Promise => - apiClient.post(endpoints.CONVERSATION.FEEDBACK, data), - getConversation: (id: string): Promise => - apiClient.get(endpoints.CONVERSATION.CONVERSATION(id)), - getConversations: (): Promise => - apiClient.get(endpoints.CONVERSATION.CONVERSATIONS), - shareConversation: (isPromptable: boolean, data: any): Promise => + answer: ( + data: any, + token: string | null, + signal: AbortSignal, + ): Promise => + apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, {}, signal), + answerStream: ( + data: any, + token: string | null, + signal: AbortSignal, + ): Promise => + apiClient.post( + endpoints.CONVERSATION.ANSWER_STREAMING, + data, + token, + {}, + signal, + ), + search: (data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.SEARCH, data, token, {}), + feedback: (data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.FEEDBACK, data, token, {}), + getConversation: (id: string, token: string | null): Promise => + apiClient.get(endpoints.CONVERSATION.CONVERSATION(id), token, {}), + getConversations: (token: string | null): Promise => + apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}), + shareConversation: ( + isPromptable: boolean, + data: any, + token: string | null, + ): Promise => apiClient.post( endpoints.CONVERSATION.SHARE_CONVERSATION(isPromptable), data, + token, + {}, ), - getSharedConversation: (identifier: string): Promise => - apiClient.get(endpoints.CONVERSATION.SHARED_CONVERSATION(identifier)), - delete: (id: string, data: any): Promise => - apiClient.post(endpoints.CONVERSATION.DELETE(id), data), - deleteAll: (): Promise => - apiClient.get(endpoints.CONVERSATION.DELETE_ALL), - update: (data: any): Promise => - apiClient.post(endpoints.CONVERSATION.UPDATE, data), + getSharedConversation: ( + identifier: string, + token: string | null, + ): Promise => + apiClient.get( + endpoints.CONVERSATION.SHARED_CONVERSATION(identifier), + token, + {}, + ), + delete: (id: string, data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.DELETE(id), data, token, {}), + deleteAll: (token: string | null): Promise => + apiClient.get(endpoints.CONVERSATION.DELETE_ALL, token, {}), + update: (data: any, token: string | null): Promise => + apiClient.post(endpoints.CONVERSATION.UPDATE, data, token, {}), }; export default conversationService; diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index e7f367f1..fbaa06ba 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -2,63 +2,71 @@ import apiClient from '../client'; import endpoints from '../endpoints'; const userService = { - getDocs: (): Promise => apiClient.get(`${endpoints.USER.DOCS}`), - getDocsWithPagination: (query: string): Promise => - apiClient.get(`${endpoints.USER.DOCS_PAGINATED}?${query}`), - checkDocs: (data: any): Promise => - apiClient.post(endpoints.USER.DOCS_CHECK, data), - getAPIKeys: (): Promise => apiClient.get(endpoints.USER.API_KEYS), - createAPIKey: (data: any): Promise => - apiClient.post(endpoints.USER.CREATE_API_KEY, data), - deleteAPIKey: (data: any): Promise => - apiClient.post(endpoints.USER.DELETE_API_KEY, data), - getPrompts: (): Promise => apiClient.get(endpoints.USER.PROMPTS), - createPrompt: (data: any): Promise => - apiClient.post(endpoints.USER.CREATE_PROMPT, data), - deletePrompt: (data: any): Promise => - apiClient.post(endpoints.USER.DELETE_PROMPT, data), - updatePrompt: (data: any): Promise => - apiClient.post(endpoints.USER.UPDATE_PROMPT, data), - getSinglePrompt: (id: string): Promise => - apiClient.get(endpoints.USER.SINGLE_PROMPT(id)), - deletePath: (docPath: string): Promise => - apiClient.get(endpoints.USER.DELETE_PATH(docPath)), - getTaskStatus: (task_id: string): Promise => - apiClient.get(endpoints.USER.TASK_STATUS(task_id)), - getMessageAnalytics: (data: any): Promise => - apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data), - getTokenAnalytics: (data: any): Promise => - apiClient.post(endpoints.USER.TOKEN_ANALYTICS, data), - getFeedbackAnalytics: (data: any): Promise => - apiClient.post(endpoints.USER.FEEDBACK_ANALYTICS, data), - getLogs: (data: any): Promise => - apiClient.post(endpoints.USER.LOGS, data), - manageSync: (data: any): Promise => - apiClient.post(endpoints.USER.MANAGE_SYNC, data), - getAvailableTools: (): Promise => - apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS), - getUserTools: (): Promise => - apiClient.get(endpoints.USER.GET_USER_TOOLS), - createTool: (data: any): Promise => - apiClient.post(endpoints.USER.CREATE_TOOL, data), - updateToolStatus: (data: any): Promise => - apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data), - updateTool: (data: any): Promise => - apiClient.post(endpoints.USER.UPDATE_TOOL, data), - deleteTool: (data: any): Promise => - apiClient.post(endpoints.USER.DELETE_TOOL, data), + getDocs: (token: string | null): Promise => + apiClient.get(`${endpoints.USER.DOCS}`, token), + getDocsWithPagination: (query: string, token: string | null): Promise => + apiClient.get(`${endpoints.USER.DOCS_PAGINATED}?${query}`, token), + checkDocs: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DOCS_CHECK, data, token), + getAPIKeys: (token: string | null): Promise => + apiClient.get(endpoints.USER.API_KEYS, token), + createAPIKey: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.CREATE_API_KEY, data, token), + deleteAPIKey: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DELETE_API_KEY, data, token), + getPrompts: (token: string | null): Promise => + apiClient.get(endpoints.USER.PROMPTS, token), + createPrompt: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.CREATE_PROMPT, data, token), + deletePrompt: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DELETE_PROMPT, data, token), + updatePrompt: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.UPDATE_PROMPT, data, token), + getSinglePrompt: (id: string, token: string | null): Promise => + apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token), + deletePath: (docPath: string, token: string | null): Promise => + apiClient.get(endpoints.USER.DELETE_PATH(docPath), token), + getTaskStatus: (task_id: string, token: string | null): Promise => + apiClient.get(endpoints.USER.TASK_STATUS(task_id), token), + getMessageAnalytics: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token), + getTokenAnalytics: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.TOKEN_ANALYTICS, data, token), + getFeedbackAnalytics: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.FEEDBACK_ANALYTICS, data, token), + getLogs: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.LOGS, data, token), + manageSync: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MANAGE_SYNC, data, token), + getAvailableTools: (token: string | null): Promise => + apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token), + getUserTools: (token: string | null): Promise => + apiClient.get(endpoints.USER.GET_USER_TOOLS, token), + createTool: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.CREATE_TOOL, data, token), + updateToolStatus: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data, token), + updateTool: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL, data, token), + deleteTool: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.DELETE_TOOL, data, token), getDocumentChunks: ( docId: string, page: number, perPage: number, + token: string | null, ): Promise => - apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage)), - addChunk: (data: any): Promise => - apiClient.post(endpoints.USER.ADD_CHUNK, data), - deleteChunk: (docId: string, chunkId: string): Promise => - apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId)), - updateChunk: (data: any): Promise => - apiClient.put(endpoints.USER.UPDATE_CHUNK, data), + apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage), token), + addChunk: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.ADD_CHUNK, data, token), + deleteChunk: ( + docId: string, + chunkId: string, + token: string | null, + ): Promise => + apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId), token), + updateChunk: (data: any, token: string | null): Promise => + apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token), }; export default userService; diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx index 29438f93..6198ccf2 100644 --- a/frontend/src/conversation/Conversation.tsx +++ b/frontend/src/conversation/Conversation.tsx @@ -15,7 +15,10 @@ import Spinner from '../assets/spinner.svg'; import RetryIcon from '../components/RetryIcon'; import { useDarkTheme, useMediaQuery } from '../hooks'; import { ShareConversationModal } from '../modals/ShareConversationModal'; -import { selectConversationId } from '../preferences/preferenceSlice'; +import { + selectConversationId, + selectToken, +} from '../preferences/preferenceSlice'; import { AppDispatch } from '../store'; import ConversationBubble from './ConversationBubble'; import { handleSendFeedback } from './conversationHandlers'; @@ -34,6 +37,7 @@ import Upload from '../upload/Upload'; import { ActiveState } from '../models/misc'; export default function Conversation() { + const token = useSelector(selectToken); const queries = useSelector(selectQueries); const navigate = useNavigate(); const status = useSelector(selectStatus); @@ -161,6 +165,7 @@ export default function Conversation() { feedback, conversationId as string, index, + token, ).catch(() => handleSendFeedback( query.prompt, @@ -168,6 +173,7 @@ export default function Conversation() { feedback, conversationId as string, index, + token, ).catch(() => dispatch(updateQuery({ index, query: { feedback: prevFeedback } })), ), diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index 0b54a366..88771fc5 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -6,6 +6,7 @@ import { ToolCallsType } from './types'; export function handleFetchAnswer( question: string, signal: AbortSignal, + token: string | null, selectedDocs: Doc | null, history: Array = [], conversationId: string | null, @@ -52,7 +53,7 @@ export function handleFetchAnswer( } payload.retriever = selectedDocs?.retriever as string; return conversationService - .answer(payload, signal) + .answer(payload, token, signal) .then((response) => { if (response.ok) { return response.json(); @@ -76,6 +77,7 @@ export function handleFetchAnswer( export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, + token: string | null, selectedDocs: Doc | null, history: Array = [], conversationId: string | null, @@ -109,7 +111,7 @@ export function handleFetchAnswerSteaming( return new Promise((resolve, reject) => { conversationService - .answerStream(payload, signal) + .answerStream(payload, token, signal) .then((response) => { if (!response.body) throw Error('No response body'); @@ -160,6 +162,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, + token: string | null, selectedDocs: Doc | null, conversation_id: string | null, history: Array = [], @@ -185,7 +188,7 @@ export function handleSearch( payload.active_docs = selectedDocs.id as string; payload.retriever = selectedDocs?.retriever as string; return conversationService - .search(payload) + .search(payload, token) .then((response) => response.json()) .then((data) => { return data; @@ -206,11 +209,14 @@ export function handleSearchViaApiKey( }; }); return conversationService - .search({ - question: question, - history: JSON.stringify(history), - api_key: api_key, - }) + .search( + { + question: question, + history: JSON.stringify(history), + api_key: api_key, + }, + null, + ) .then((response) => response.json()) .then((data) => { return data; @@ -224,15 +230,19 @@ export function handleSendFeedback( feedback: FEEDBACK, conversation_id: string, prompt_index: number, + token: string | null, ) { return conversationService - .feedback({ - question: prompt, - answer: response, - feedback: feedback, - conversation_id: conversation_id, - question_index: prompt_index, - }) + .feedback( + { + question: prompt, + answer: response, + feedback: feedback, + conversation_id: conversation_id, + question_index: prompt_index, + }, + token, + ) .then((response) => { if (response.ok) { return Promise.resolve(); @@ -265,7 +275,7 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations save_conversation: false, }; conversationService - .answerStream(payload, signal) + .answerStream(payload, null, signal) .then((response) => { if (!response.body) throw Error('No response body'); @@ -339,6 +349,7 @@ export function handleFetchSharedAnswer( question: question, api_key: apiKey, }, + null, signal, ) .then((response) => { diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 4473d3a9..651ead04 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -42,6 +42,7 @@ export const fetchAnswer = createAsyncThunk< await handleFetchAnswerSteaming( question, signal, + state.preference.token, state.preference.selectedDocs!, state.conversation.queries, state.conversation.conversationId, @@ -53,7 +54,7 @@ export const fetchAnswer = createAsyncThunk< if (data.type === 'end') { dispatch(conversationSlice.actions.setStatus('idle')); - getConversations() + getConversations(state.preference.token) .then((fetchedConversations) => { dispatch(setConversations(fetchedConversations)); }) @@ -114,6 +115,7 @@ export const fetchAnswer = createAsyncThunk< const answer = await handleFetchAnswer( question, signal, + state.preference.token, state.preference.selectedDocs!, state.conversation.queries, state.conversation.conversationId, @@ -150,7 +152,7 @@ export const fetchAnswer = createAsyncThunk< }), ); dispatch(conversationSlice.actions.setStatus('idle')); - getConversations() + getConversations(state.preference.token) .then((fetchedConversations) => { dispatch(setConversations(fetchedConversations)); }) diff --git a/frontend/src/modals/JWTModal.tsx b/frontend/src/modals/JWTModal.tsx new file mode 100644 index 00000000..90a13254 --- /dev/null +++ b/frontend/src/modals/JWTModal.tsx @@ -0,0 +1,54 @@ +import React, { useState } from 'react'; +import { useDispatch } from 'react-redux'; + +import Input from '../components/Input'; +import { ActiveState } from '../models/misc'; +import { setToken } from '../preferences/preferenceSlice'; +import WrapperModal from './WrapperModal'; + +type JWTModalProps = { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; +}; + +export default function JWTModal({ modalState, setModalState }: JWTModalProps) { + const dispatch = useDispatch(); + const [jwtToken, setJwtToken] = useState(''); + + const handleSaveToken = () => { + if (jwtToken) { + localStorage.setItem('authToken', jwtToken); + dispatch(setToken(jwtToken)); + setModalState('INACTIVE'); + } + }; + + if (modalState !== 'ACTIVE') return null; + + return ( + setModalState('INACTIVE')} className="p-4"> +
+ + Add JWT Token + +
+
+ setJwtToken(e.target.value)} + borderVariant="thin" + /> +
+ +
+ ); +} diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts index 8d21bdcd..d52580e0 100644 --- a/frontend/src/preferences/preferenceApi.ts +++ b/frontend/src/preferences/preferenceApi.ts @@ -3,9 +3,9 @@ import userService from '../api/services/userService'; import { Doc, GetDocsResponse } from '../models/misc'; //Fetches all JSON objects from the source. We only use the objects with the "model" property in SelectDocsModal.tsx. Hopefully can clean up the source file later. -export async function getDocs(): Promise { +export async function getDocs(token: string | null): Promise { try { - const response = await userService.getDocs(); + const response = await userService.getDocs(token); const data = await response.json(); const docs: Doc[] = []; @@ -26,10 +26,11 @@ export async function getDocsWithPagination( pageNumber = 1, rowsPerPage = 10, searchTerm = '', + token: string | null, ): Promise { try { const query = `sort=${sort}&order=${order}&page=${pageNumber}&rows=${rowsPerPage}&search=${searchTerm}`; - const response = await userService.getDocsWithPagination(query); + const response = await userService.getDocsWithPagination(query, token); const data = await response.json(); const docs: Doc[] = []; Array.isArray(data.paginated) && @@ -48,12 +49,12 @@ export async function getDocsWithPagination( } } -export async function getConversations(): Promise<{ +export async function getConversations(token: string | null): Promise<{ data: { name: string; id: string }[] | null; loading: boolean; }> { try { - const response = await conversationService.getConversations(); + const response = await conversationService.getConversations(token); const data = await response.json(); const conversations: { name: string; id: string }[] = []; @@ -100,8 +101,11 @@ export function setLocalRecentDocs(doc: Doc | null): void { docPath = 'local' + '/' + doc.name + '/'; } userService - .checkDocs({ - docs: docPath, - }) + .checkDocs( + { + docs: docPath, + }, + null, + ) .then((response) => response.json()); } diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts index 8b3064d5..4bca1a37 100644 --- a/frontend/src/preferences/preferenceSlice.ts +++ b/frontend/src/preferences/preferenceSlice.ts @@ -19,6 +19,7 @@ export interface Preference { data: { name: string; id: string }[] | null; loading: boolean; }; + token: string | null; modalState: ActiveState; paginatedDocuments: Doc[] | null; } @@ -42,6 +43,7 @@ const initialState: Preference = { data: null, loading: false, }, + token: localStorage.getItem('authToken') || null, modalState: 'INACTIVE', paginatedDocuments: null, }; @@ -65,6 +67,9 @@ export const prefSlice = createSlice({ setConversations: (state, action) => { state.conversations = action.payload; }, + setToken: (state, action) => { + state.token = action.payload; + }, setPrompt: (state, action) => { state.prompt = action.payload; }, @@ -85,6 +90,7 @@ export const { setSelectedDocs, setSourceDocs, setConversations, + setToken, setPrompt, setChunks, setTokenLimit, @@ -157,6 +163,7 @@ export const selectConversations = (state: RootState) => state.preference.conversations; export const selectConversationId = (state: RootState) => state.conversation.conversationId; +export const selectToken = (state: RootState) => state.preference.token; export const selectPrompt = (state: RootState) => state.preference.prompt; export const selectChunks = (state: RootState) => state.preference.chunks; export const selectTokenLimit = (state: RootState) => diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 8f426ed6..02aa9a68 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -16,6 +16,7 @@ const doc = localStorage.getItem('DocsGPTRecentDocs'); const preloadedState: { preference: Preference } = { preference: { apiKey: key ?? '', + token: localStorage.getItem('authToken') ?? null, prompt: prompt !== null ? JSON.parse(prompt)