feat: implement JWT authentication and token management in frontend and backend

This commit is contained in:
Siddhant Rai
2025-03-14 17:07:15 +05:30
parent fe02bf9347
commit 7fd377bdbe
17 changed files with 453 additions and 178 deletions

View File

@@ -122,6 +122,7 @@ def save_conversation(
source_log_docs,
tool_calls,
llm,
decoded_token,
index=None,
api_key=None,
):
@@ -180,7 +181,7 @@ def save_conversation(
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_data = {
"user": "local",
"user": decoded_token.get("sub"),
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
@@ -221,6 +222,7 @@ def complete_stream(
retriever,
conversation_id,
user_api_key,
decoded_token,
isNoneDoc=False,
index=None,
should_save_conversation=True,
@@ -271,6 +273,7 @@ def complete_stream(
source_log_docs,
tool_calls,
llm,
decoded_token,
index,
api_key=user_api_key,
)
@@ -286,7 +289,7 @@ def complete_stream(
{
"action": "stream_answer",
"level": "info",
"user": "local",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
@@ -381,15 +384,21 @@ class Stream(Resource):
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return bad_request(401, "Unauthorized")
logger.info(
f"/stream - request_data: {data}, source: {source}",
@@ -429,6 +438,7 @@ class Stream(Resource):
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
@@ -521,13 +531,21 @@ class Answer(Resource):
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return bad_request(401, "Unauthorized")
prompt = get_prompt(prompt_id)
@@ -614,6 +632,7 @@ class Answer(Resource):
source_log_docs,
tool_calls,
llm,
decoded_token,
api_key=user_api_key,
)
)
@@ -623,7 +642,7 @@ class Answer(Resource):
{
"action": "api_answer",
"level": "info",
"user": "local",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
@@ -692,12 +711,17 @@ class Search(Resource):
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
@@ -723,7 +747,7 @@ class Search(Resource):
{
"action": "api_search",
"level": "info",
"user": "local",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"sources": docs,

View File

@@ -15,7 +15,6 @@ from werkzeug.utils import secure_filename
from application.agents.tools.tool_manager import ToolManager
from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
@@ -110,11 +109,18 @@ class GetConversations(Resource):
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
conversations = conversations_collection.find(
{"api_key": {"$exists": False}}
).sort("date", -1).limit(30)
conversations = (
conversations_collection.find(
{"api_key": {"$exists": False}, "user": decoded_token.get("sub")}
)
.sort("date", -1)
.limit(30)
)
list_conversations = [
{"id": str(conversation["_id"]), "name": conversation["name"]}
for conversation in conversations
@@ -132,6 +138,9 @@ class GetSingleConversation(Resource):
params={"id": "The conversation ID"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
@@ -140,7 +149,7 @@ class GetSingleConversation(Resource):
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
@@ -227,7 +236,7 @@ class SubmitFeedback(Resource):
{
"$unset": {
f"queries.{data['question_index']}.feedback": "",
f"queries.{data['question_index']}.feedback_timestamp": ""
f"queries.{data['question_index']}.feedback_timestamp": "",
}
},
)
@@ -240,8 +249,12 @@ class SubmitFeedback(Resource):
},
{
"$set": {
f"queries.{data['question_index']}.feedback": data["feedback"],
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(datetime.timezone.utc)
f"queries.{data['question_index']}.feedback": data[
"feedback"
],
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
datetime.timezone.utc
),
}
},
)
@@ -1211,7 +1224,13 @@ class GetMessageAnalytics(Resource):
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=["last_hour", "last_24_hour", "last_7_days", "last_15_days", "last_30_days"],
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@@ -1244,9 +1263,9 @@ class GetMessageAnalytics(Resource):
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6 if filter_option == "last_7_days"
else 14 if filter_option == "last_15_days"
else 29
6
if filter_option == "last_7_days"
else 14 if filter_option == "last_15_days" else 29
)
else:
return make_response(
@@ -1254,25 +1273,20 @@ class GetMessageAnalytics(Resource):
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
try:
pipeline = [
# Initial match for API key if provided
{
"$match": {
"api_key": api_key if api_key else {"$exists": False}
}
},
{"$match": {"api_key": api_key if api_key else {"$exists": False}}},
{"$unwind": "$queries"},
# Match queries within the time range
{
"$match": {
"queries.timestamp": {
"$gte": start_date,
"$lte": end_date
}
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
}
},
# Group by formatted timestamp
@@ -1281,14 +1295,14 @@ class GetMessageAnalytics(Resource):
"_id": {
"$dateToString": {
"format": group_format,
"date": "$queries.timestamp"
"date": "$queries.timestamp",
}
},
"count": {"$sum": 1}
"count": {"$sum": 1},
}
},
# Sort by timestamp
{"$sort": {"_id": 1}}
{"$sort": {"_id": 1}},
]
message_data = conversations_collection.aggregate(pipeline)
@@ -1511,11 +1525,21 @@ class GetFeedbackAnalytics(Resource):
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
@@ -1533,13 +1557,21 @@ class GetFeedbackAnalytics(Resource):
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
date_field = {"$dateToString": {"format": group_format, "date": "$queries.feedback_timestamp"}}
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
try:
match_stage = {
"$match": {
"queries.feedback_timestamp": {"$gte": start_date, "$lte": end_date},
"queries.feedback": {"$exists": True}
"queries.feedback_timestamp": {
"$gte": start_date,
"$lte": end_date,
},
"queries.feedback": {"$exists": True},
}
}
if api_key:

View File

@@ -1,20 +1,26 @@
import platform
import dotenv
from flask import Flask, redirect, request
from flask import Flask, jsonify, redirect, request
from jose import jwt
from application.auth import get_or_create_user_id, handle_auth
from application.core.logging_config import setup_logging
setup_logging()
from application.api.answer.routes import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.extensions import api # noqa: E402
from application.api.answer.routes import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.extensions import api # noqa: E402
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
@@ -32,6 +38,13 @@ app.config.update(
celery.config_from_object("application.celeryconfig")
api.init_app(app)
SIMPLE_JWT_TOKEN = None
if settings.AUTH_TYPE == "simple_jwt":
user_id = get_or_create_user_id()
payload = {"sub": user_id}
SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256")
print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}")
@app.route("/")
def home():
@@ -41,11 +54,27 @@ def home():
return "Welcome to DocsGPT Backend!"
@app.before_request
def authenticate_request():
    """Attach the caller's decoded token to the request before each view runs.

    OPTIONS requests are answered immediately with an empty 200 response and
    never reach ``handle_auth``. For every other method the outcome of
    ``handle_auth(request)`` decides what happens:

    * a result containing ``"message"`` sets ``request.decoded_token`` to
      ``None`` and lets the request continue;
    * a result containing ``"error"`` short-circuits with that result as a
      JSON body and status 401;
    * anything else is stored as ``request.decoded_token``.
    """
    if request.method == "OPTIONS":
        return "", 200

    auth_result = handle_auth(request)

    # "message" is checked before "error", so a result carrying both keys is
    # treated as "no token" rather than rejected.
    if "message" in auth_result:
        request.decoded_token = None
        return None

    if "error" in auth_result:
        return jsonify(auth_result), 401

    request.decoded_token = auth_result
    return None
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
response.headers.add(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
return response

39
application/auth.py Normal file
View File

@@ -0,0 +1,39 @@
import uuid
from jose import jwt
from application.core.settings import settings
def handle_auth(request, data=None):
    """Resolve the caller's identity from the request's Authorization header.

    Args:
        request: Object exposing a ``headers`` mapping with a ``get`` method.
        data: Unused by this function; kept so existing call sites that pass
            it keep working. Defaults to ``None`` instead of a shared mutable
            dict.

    Returns:
        dict: One of
            * the decoded JWT claims, when ``settings.AUTH_TYPE`` is
              ``"simple_jwt"`` and the token decodes successfully;
            * ``{"message": "Missing Authorization header"}`` when the header
              is absent or empty;
            * ``{"message": "Authentication error: ..."}`` when decoding
              raises;
            * ``{"sub": "local"}`` for any other ``AUTH_TYPE``.
    """
    if settings.AUTH_TYPE != "simple_jwt":
        return {"sub": "local"}

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return {"message": "Missing Authorization header"}

    # Remove the scheme only when it is a leading prefix; a blanket
    # str.replace would also rewrite any later occurrence in the header value.
    scheme = "Bearer "
    if auth_header.startswith(scheme):
        jwt_token = auth_header[len(scheme):]
    else:
        jwt_token = auth_header
    jwt_token = jwt_token.strip()

    try:
        # verify_exp is disabled, so the "exp" claim is not enforced here.
        return jwt.decode(
            jwt_token,
            settings.JWT_SECRET_KEY,
            algorithms=["HS256"],
            options={"verify_exp": False},
        )
    except Exception as e:
        # Decoding failures are reported to the caller as a message dict
        # rather than raised.
        return {"message": f"Authentication error: {str(e)}"}
def get_or_create_user_id():
    """Return the persisted local user id, creating it when necessary.

    The id is read from ``settings.USER_ID_FILE`` with surrounding whitespace
    stripped. If the file does not exist, or exists but contains no id, a new
    UUID4 string is generated, written to that file and returned.

    Returns:
        str: The stored or newly generated user id (never empty).
    """
    try:
        with open(settings.USER_ID_FILE, "r") as f:
            user_id = f.read().strip()
    except FileNotFoundError:
        user_id = ""

    if user_id:
        return user_id

    # A blank file would otherwise yield "" as the id; regenerate instead so
    # callers always receive a usable identifier.
    user_id = str(uuid.uuid4())
    with open(settings.USER_ID_FILE, "w") as f:
        f.write(user_id)
    return user_id

View File

@@ -10,6 +10,7 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = "simple_jwt"
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = (
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
@@ -98,6 +99,9 @@ class Settings(BaseSettings):
FLASK_DEBUG_MODE: bool = False
JWT_SECRET_KEY: str = ""
USER_ID_FILE: str = os.path.join(current_dir, "user_id.txt")
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -69,6 +69,7 @@ pymongo==4.10.1
pypdf==5.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-jose==3.4.0
python-pptx==1.0.2
qdrant-client==1.13.2
redis==5.2.1