refactor: answer routes to comply with OpenAPI spec using flask-restx

This commit is contained in:
Siddhant Rai
2024-09-30 00:41:34 +05:30
parent b084e3074d
commit e8988e82d0
5 changed files with 447 additions and 347 deletions

View File

@@ -7,13 +7,15 @@ from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, jsonify, make_response, request
from flask_restx import Api, fields, Namespace, Resource
from flask_restx import fields, Namespace, Resource
from pymongo import MongoClient
from werkzeug.utils import secure_filename
from application.api.user.tasks import ingest, ingest_remote
from application.core.settings import settings
from application.extensions import api
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
mongo = MongoClient(settings.MONGO_URI)
@@ -28,14 +30,8 @@ shared_conversations_collections = db["shared_conversations"]
user_logs_collection = db["user_logs"]
user = Blueprint("user", __name__)
api = Api(
user,
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
default="user",
default_label="User operations",
)
user_ns = Namespace("user", description="User related operations", path="/")
api.add_namespace(user_ns)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -63,22 +59,7 @@ def generate_date_range(start_date, end_date):
}
def check_required_fields(data, required_fields):
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Missing fields: {', '.join(missing_fields)}",
}
),
400,
)
return None
@api.route("/api/delete_conversation")
@user_ns.route("/api/delete_conversation")
class DeleteConversation(Resource):
@api.doc(
description="Deletes a conversation by ID",
@@ -98,7 +79,7 @@ class DeleteConversation(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/delete_all_conversations")
@user_ns.route("/api/delete_all_conversations")
class DeleteAllConversations(Resource):
@api.doc(
description="Deletes all conversations for a specific user",
@@ -112,7 +93,7 @@ class DeleteAllConversations(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/get_conversations")
@user_ns.route("/api/get_conversations")
class GetConversations(Resource):
@api.doc(
description="Retrieve a list of the latest 30 conversations",
@@ -129,7 +110,7 @@ class GetConversations(Resource):
return make_response(jsonify(list_conversations), 200)
@api.route("/api/get_single_conversation")
@user_ns.route("/api/get_single_conversation")
class GetSingleConversation(Resource):
@api.doc(
description="Retrieve a single conversation by ID",
@@ -153,7 +134,7 @@ class GetSingleConversation(Resource):
return make_response(jsonify(conversation["queries"]), 200)
@api.route("/api/update_conversation_name")
@user_ns.route("/api/update_conversation_name")
class UpdateConversationName(Resource):
@api.expect(
api.model(
@@ -186,7 +167,7 @@ class UpdateConversationName(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/feedback")
@user_ns.route("/api/feedback")
class SubmitFeedback(Resource):
@api.expect(
api.model(
@@ -229,7 +210,7 @@ class SubmitFeedback(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/delete_by_ids")
@user_ns.route("/api/delete_by_ids")
class DeleteByIds(Resource):
@api.doc(
description="Deletes documents from the vector store by IDs",
@@ -252,7 +233,7 @@ class DeleteByIds(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@api.route("/api/delete_old")
@user_ns.route("/api/delete_old")
class DeleteOldIndexes(Resource):
@api.doc(
description="Deletes old indexes",
@@ -289,7 +270,7 @@ class DeleteOldIndexes(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/upload")
@user_ns.route("/api/upload")
class UploadFile(Resource):
@api.expect(
api.model(
@@ -370,7 +351,7 @@ class UploadFile(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@api.route("/api/remote")
@user_ns.route("/api/remote")
class UploadRemote(Resource):
@api.expect(
api.model(
@@ -408,7 +389,7 @@ class UploadRemote(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@api.route("/api/task_status")
@user_ns.route("/api/task_status")
class TaskStatus(Resource):
task_status_model = api.model(
"TaskStatusModel",
@@ -435,7 +416,7 @@ class TaskStatus(Resource):
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
@api.route("/api/combine")
@user_ns.route("/api/combine")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
def get(self):
@@ -496,7 +477,7 @@ class CombinedJson(Resource):
return make_response(jsonify(data), 200)
@api.route("/api/docs_check")
@user_ns.route("/api/docs_check")
class CheckDocs(Resource):
check_docs_model = api.model(
"CheckDocsModel",
@@ -522,7 +503,7 @@ class CheckDocs(Resource):
return make_response(jsonify({"status": "not found"}), 404)
@api.route("/api/create_prompt")
@user_ns.route("/api/create_prompt")
class CreatePrompt(Resource):
create_prompt_model = api.model(
"CreatePromptModel",
@@ -560,7 +541,7 @@ class CreatePrompt(Resource):
return make_response(jsonify({"id": new_id}), 200)
@api.route("/api/get_prompts")
@user_ns.route("/api/get_prompts")
class GetPrompts(Resource):
@api.doc(description="Get all prompts for the user")
def get(self):
@@ -587,7 +568,7 @@ class GetPrompts(Resource):
return make_response(jsonify(list_prompts), 200)
@api.route("/api/get_single_prompt")
@user_ns.route("/api/get_single_prompt")
class GetSinglePrompt(Resource):
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
def get(self):
@@ -628,7 +609,7 @@ class GetSinglePrompt(Resource):
return make_response(jsonify({"content": prompt["content"]}), 200)
@api.route("/api/delete_prompt")
@user_ns.route("/api/delete_prompt")
class DeletePrompt(Resource):
delete_prompt_model = api.model(
"DeletePromptModel",
@@ -652,7 +633,7 @@ class DeletePrompt(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/update_prompt")
@user_ns.route("/api/update_prompt")
class UpdatePrompt(Resource):
update_prompt_model = api.model(
"UpdatePromptModel",
@@ -685,7 +666,7 @@ class UpdatePrompt(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/get_api_keys")
@user_ns.route("/api/get_api_keys")
class GetApiKeys(Resource):
@api.doc(description="Retrieve API keys for the user")
def get(self):
@@ -719,7 +700,7 @@ class GetApiKeys(Resource):
return make_response(jsonify(list_keys), 200)
@api.route("/api/create_api_key")
@user_ns.route("/api/create_api_key")
class CreateApiKey(Resource):
create_api_key_model = api.model(
"CreateApiKeyModel",
@@ -764,7 +745,7 @@ class CreateApiKey(Resource):
return make_response(jsonify({"id": new_id, "key": key}), 201)
@api.route("/api/delete_api_key")
@user_ns.route("/api/delete_api_key")
class DeleteApiKey(Resource):
delete_api_key_model = api.model(
"DeleteApiKeyModel",
@@ -790,7 +771,7 @@ class DeleteApiKey(Resource):
return {"success": True}, 200
@api.route("/api/share")
@user_ns.route("/api/share")
class ShareConversation(Resource):
share_conversation_model = api.model(
"ShareConversationModel",
@@ -988,7 +969,7 @@ class ShareConversation(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@api.route("/api/shared_conversation/<string:identifier>")
@user_ns.route("/api/shared_conversation/<string:identifier>")
class GetPubliclySharedConversations(Resource):
@api.doc(description="Get publicly shared conversations by identifier")
def get(self, identifier: str):
@@ -1043,7 +1024,7 @@ class GetPubliclySharedConversations(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@api.route("/api/get_message_analytics")
@user_ns.route("/api/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
"GetMessageAnalyticsModel",
@@ -1181,7 +1162,7 @@ class GetMessageAnalytics(Resource):
)
@api.route("/api/get_token_analytics")
@user_ns.route("/api/get_token_analytics")
class GetTokenAnalytics(Resource):
get_token_analytics_model = api.model(
"GetTokenAnalyticsModel",
@@ -1332,7 +1313,7 @@ class GetTokenAnalytics(Resource):
)
@api.route("/api/get_feedback_analytics")
@user_ns.route("/api/get_feedback_analytics")
class GetFeedbackAnalytics(Resource):
get_feedback_analytics_model = api.model(
"GetFeedbackAnalyticsModel",
@@ -1550,7 +1531,7 @@ class GetFeedbackAnalytics(Resource):
)
@api.route("/api/get_user_logs")
@user_ns.route("/api/get_user_logs")
class GetUserLogs(Resource):
get_user_logs_model = api.model(
"GetUserLogsModel",
@@ -1629,7 +1610,7 @@ class GetUserLogs(Resource):
)
@api.route("/api/manage_sync")
@user_ns.route("/api/manage_sync")
class ManageSync(Resource):
manage_sync_model = api.model(
"ManageSyncModel",