refactor: answer routes to comply with OpenAPI spec using flask-restx

This commit is contained in:
Siddhant Rai
2024-09-30 00:41:34 +05:30
parent b084e3074d
commit e8988e82d0
5 changed files with 447 additions and 347 deletions

View File

@@ -1,20 +1,24 @@
import asyncio
import datetime
import json
import logging
import os
import sys
from flask import Blueprint, request, Response, current_app
import json
import datetime
import logging
import traceback
from pymongo import MongoClient
from bson.objectid import ObjectId
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, current_app, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from pymongo import MongoClient
from application.core.settings import settings
from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -25,7 +29,10 @@ sources_collection = db["sources"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
user_logs_collection = db["user_logs"]
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
@@ -186,10 +193,10 @@ def complete_stream(
answer = retriever.gen()
sources = retriever.search()
for source in sources:
if("text" in source):
source["text"] = source["text"][:100].strip()+"..."
if(len(sources) > 0):
data = json.dumps({"type":"source","source":sources})
if "text" in source:
source["text"] = source["text"][:100].strip() + "..."
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
@@ -243,109 +250,133 @@ def complete_stream(
return
@answer.route("/stream", methods=["POST"])
def stream():
try:
data = request.get_json()
question = data["question"]
if "history" not in data:
history = []
else:
history = data["history"]
history = json.loads(history)
if "conversation_id" not in data:
conversation_id = None
else:
conversation_id = data["conversation_id"]
if "prompt_id" in data:
prompt_id = data["prompt_id"]
else:
prompt_id = "default"
if "selectedDocs" in data and data["selectedDocs"] is None:
chunks = 0
elif "chunks" in data:
chunks = int(data["chunks"])
else:
chunks = 2
if "token_limit" in data:
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
## retriever can be "brave_search, duckduck_search or classic"
retriever_name = data["retriever"] if "retriever" in data else "classic"
# check if active_docs or api_key is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
@answer_ns.route("/stream")
class Stream(Resource):
stream_model = api.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
mimetype="text/event-stream",
)
"history": fields.List(
fields.String, required=False, description="Chat history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"selectedDocs": fields.String(
required=False, description="Selected documents"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
except ValueError:
message = "Malformed request body"
print("\033[91merr", str(message), file=sys.stderr)
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
current_app.logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
message = e.args[0]
status_code = 400
# # Custom exceptions with two arguments, index 1 as status code
if len(e.args) >= 2:
status_code = e.args[1]
return Response(
error_stream_generate(message),
status=status_code,
mimetype="text/event-stream",
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
history = json.loads(history)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
if "selectedDocs" in data and data["selectedDocs"] is None:
chunks = 0
else:
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
),
mimetype="text/event-stream",
)
except ValueError:
message = "Malformed request body"
print("\033[91merr", str(message), file=sys.stderr)
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
current_app.logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
message = e.args[0]
status_code = 400
# Custom exceptions with two arguments, index 1 as status code
if len(e.args) >= 2:
status_code = e.args[1]
return Response(
error_stream_generate(message),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
@@ -353,180 +384,235 @@ def error_stream_generate(err_response):
yield f"data: {data}\n\n"
@answer.route("/api/answer", methods=["POST"])
def api_answer():
data = request.get_json()
question = data["question"]
if "history" not in data:
history = []
else:
history = data["history"]
if "conversation_id" not in data:
conversation_id = None
else:
conversation_id = data["conversation_id"]
print("-" * 5)
if "prompt_id" in data:
prompt_id = data["prompt_id"]
else:
prompt_id = "default"
if "chunks" in data:
chunks = int(data["chunks"])
else:
chunks = 2
if "token_limit" in data:
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
## retriever can be brave_search, duckduck_search or classic
retriever_name = data["retriever"] if "retriever" in data else "classic"
# use try and except to check for exception
try:
# check if the vectorstore is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
prompt = get_prompt(prompt_id)
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
return result
except Exception as e:
current_app.logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
@answer.route("/api/search", methods=["POST"])
def api_search():
data = request.get_json()
question = data["question"]
if "chunks" in data:
chunks = int(data["chunks"])
else:
chunks = 2
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
source = {"active_docs":data_key["source"]}
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
else:
source = {}
user_api_key = None
if "retriever" in data:
retriever_name = data["retriever"]
else:
retriever_name = "classic"
if "token_limit" in data:
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
docs = retriever.search()
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
@answer_ns.route("/api/answer")
class Answer(Resource):
answer_model = api.model(
"AnswerModel",
{
"action": "api_search",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
"question": fields.String(
required=True, description="The question to answer"
),
"history": fields.List(
fields.String, required=False, description="Conversation history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
@api.expect(answer_model)
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
return docs
try:
question = data["question"]
history = data.get("history", [])
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}
user_api_key = None
prompt = get_prompt(prompt_id)
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
except Exception as e:
current_app.logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(result, 200)
@answer_ns.route("/api/search")
class Search(Resource):
search_model = api.model(
"SearchModel",
{
"question": fields.String(
required=True, description="The question to search"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"api_key": fields.String(
required=False, description="API key for authentication"
),
"active_docs": fields.String(
required=False, description="Active documents for retrieval"
),
"retriever": fields.String(required=False, description="Retriever type"),
"token_limit": fields.Integer(
required=False, description="Limit for tokens"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(search_model)
@api.doc(
description="Search for relevant documents based on the question and retriever"
)
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
else:
source = {}
user_api_key = None
current_app.logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
)
docs = retriever.search()
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_search",
"level": "info",
"user": "local",
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
except Exception as e:
current_app.logger.error(
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(docs, 200)

View File

@@ -7,13 +7,15 @@ from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, jsonify, make_response, request
from flask_restx import Api, fields, Namespace, Resource
from flask_restx import fields, Namespace, Resource
from pymongo import MongoClient
from werkzeug.utils import secure_filename
from application.api.user.tasks import ingest, ingest_remote
from application.core.settings import settings
from application.extensions import api
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
mongo = MongoClient(settings.MONGO_URI)
@@ -28,14 +30,8 @@ shared_conversations_collections = db["shared_conversations"]
user_logs_collection = db["user_logs"]
user = Blueprint("user", __name__)
api = Api(
user,
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
default="user",
default_label="User operations",
)
user_ns = Namespace("user", description="User related operations", path="/")
api.add_namespace(user_ns)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -63,22 +59,7 @@ def generate_date_range(start_date, end_date):
}
def check_required_fields(data, required_fields):
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Missing fields: {', '.join(missing_fields)}",
}
),
400,
)
return None
@api.route("/api/delete_conversation")
@user_ns.route("/api/delete_conversation")
class DeleteConversation(Resource):
@api.doc(
description="Deletes a conversation by ID",
@@ -98,7 +79,7 @@ class DeleteConversation(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/delete_all_conversations")
@user_ns.route("/api/delete_all_conversations")
class DeleteAllConversations(Resource):
@api.doc(
description="Deletes all conversations for a specific user",
@@ -112,7 +93,7 @@ class DeleteAllConversations(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/get_conversations")
@user_ns.route("/api/get_conversations")
class GetConversations(Resource):
@api.doc(
description="Retrieve a list of the latest 30 conversations",
@@ -129,7 +110,7 @@ class GetConversations(Resource):
return make_response(jsonify(list_conversations), 200)
@api.route("/api/get_single_conversation")
@user_ns.route("/api/get_single_conversation")
class GetSingleConversation(Resource):
@api.doc(
description="Retrieve a single conversation by ID",
@@ -153,7 +134,7 @@ class GetSingleConversation(Resource):
return make_response(jsonify(conversation["queries"]), 200)
@api.route("/api/update_conversation_name")
@user_ns.route("/api/update_conversation_name")
class UpdateConversationName(Resource):
@api.expect(
api.model(
@@ -186,7 +167,7 @@ class UpdateConversationName(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/feedback")
@user_ns.route("/api/feedback")
class SubmitFeedback(Resource):
@api.expect(
api.model(
@@ -229,7 +210,7 @@ class SubmitFeedback(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/delete_by_ids")
@user_ns.route("/api/delete_by_ids")
class DeleteByIds(Resource):
@api.doc(
description="Deletes documents from the vector store by IDs",
@@ -252,7 +233,7 @@ class DeleteByIds(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@api.route("/api/delete_old")
@user_ns.route("/api/delete_old")
class DeleteOldIndexes(Resource):
@api.doc(
description="Deletes old indexes",
@@ -289,7 +270,7 @@ class DeleteOldIndexes(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/upload")
@user_ns.route("/api/upload")
class UploadFile(Resource):
@api.expect(
api.model(
@@ -370,7 +351,7 @@ class UploadFile(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@api.route("/api/remote")
@user_ns.route("/api/remote")
class UploadRemote(Resource):
@api.expect(
api.model(
@@ -408,7 +389,7 @@ class UploadRemote(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@api.route("/api/task_status")
@user_ns.route("/api/task_status")
class TaskStatus(Resource):
task_status_model = api.model(
"TaskStatusModel",
@@ -435,7 +416,7 @@ class TaskStatus(Resource):
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
@api.route("/api/combine")
@user_ns.route("/api/combine")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
def get(self):
@@ -496,7 +477,7 @@ class CombinedJson(Resource):
return make_response(jsonify(data), 200)
@api.route("/api/docs_check")
@user_ns.route("/api/docs_check")
class CheckDocs(Resource):
check_docs_model = api.model(
"CheckDocsModel",
@@ -522,7 +503,7 @@ class CheckDocs(Resource):
return make_response(jsonify({"status": "not found"}), 404)
@api.route("/api/create_prompt")
@user_ns.route("/api/create_prompt")
class CreatePrompt(Resource):
create_prompt_model = api.model(
"CreatePromptModel",
@@ -560,7 +541,7 @@ class CreatePrompt(Resource):
return make_response(jsonify({"id": new_id}), 200)
@api.route("/api/get_prompts")
@user_ns.route("/api/get_prompts")
class GetPrompts(Resource):
@api.doc(description="Get all prompts for the user")
def get(self):
@@ -587,7 +568,7 @@ class GetPrompts(Resource):
return make_response(jsonify(list_prompts), 200)
@api.route("/api/get_single_prompt")
@user_ns.route("/api/get_single_prompt")
class GetSinglePrompt(Resource):
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
def get(self):
@@ -628,7 +609,7 @@ class GetSinglePrompt(Resource):
return make_response(jsonify({"content": prompt["content"]}), 200)
@api.route("/api/delete_prompt")
@user_ns.route("/api/delete_prompt")
class DeletePrompt(Resource):
delete_prompt_model = api.model(
"DeletePromptModel",
@@ -652,7 +633,7 @@ class DeletePrompt(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/update_prompt")
@user_ns.route("/api/update_prompt")
class UpdatePrompt(Resource):
update_prompt_model = api.model(
"UpdatePromptModel",
@@ -685,7 +666,7 @@ class UpdatePrompt(Resource):
return make_response(jsonify({"success": True}), 200)
@api.route("/api/get_api_keys")
@user_ns.route("/api/get_api_keys")
class GetApiKeys(Resource):
@api.doc(description="Retrieve API keys for the user")
def get(self):
@@ -719,7 +700,7 @@ class GetApiKeys(Resource):
return make_response(jsonify(list_keys), 200)
@api.route("/api/create_api_key")
@user_ns.route("/api/create_api_key")
class CreateApiKey(Resource):
create_api_key_model = api.model(
"CreateApiKeyModel",
@@ -764,7 +745,7 @@ class CreateApiKey(Resource):
return make_response(jsonify({"id": new_id, "key": key}), 201)
@api.route("/api/delete_api_key")
@user_ns.route("/api/delete_api_key")
class DeleteApiKey(Resource):
delete_api_key_model = api.model(
"DeleteApiKeyModel",
@@ -790,7 +771,7 @@ class DeleteApiKey(Resource):
return {"success": True}, 200
@api.route("/api/share")
@user_ns.route("/api/share")
class ShareConversation(Resource):
share_conversation_model = api.model(
"ShareConversationModel",
@@ -988,7 +969,7 @@ class ShareConversation(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@api.route("/api/shared_conversation/<string:identifier>")
@user_ns.route("/api/shared_conversation/<string:identifier>")
class GetPubliclySharedConversations(Resource):
@api.doc(description="Get publicly shared conversations by identifier")
def get(self, identifier: str):
@@ -1043,7 +1024,7 @@ class GetPubliclySharedConversations(Resource):
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@api.route("/api/get_message_analytics")
@user_ns.route("/api/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
"GetMessageAnalyticsModel",
@@ -1181,7 +1162,7 @@ class GetMessageAnalytics(Resource):
)
@api.route("/api/get_token_analytics")
@user_ns.route("/api/get_token_analytics")
class GetTokenAnalytics(Resource):
get_token_analytics_model = api.model(
"GetTokenAnalyticsModel",
@@ -1332,7 +1313,7 @@ class GetTokenAnalytics(Resource):
)
@api.route("/api/get_feedback_analytics")
@user_ns.route("/api/get_feedback_analytics")
class GetFeedbackAnalytics(Resource):
get_feedback_analytics_model = api.model(
"GetFeedbackAnalyticsModel",
@@ -1550,7 +1531,7 @@ class GetFeedbackAnalytics(Resource):
)
@api.route("/api/get_user_logs")
@user_ns.route("/api/get_user_logs")
class GetUserLogs(Resource):
get_user_logs_model = api.model(
"GetUserLogsModel",
@@ -1629,7 +1610,7 @@ class GetUserLogs(Resource):
)
@api.route("/api/manage_sync")
@user_ns.route("/api/manage_sync")
class ManageSync(Resource):
manage_sync_model = api.model(
"ManageSyncModel",

View File

@@ -1,15 +1,19 @@
import platform
import dotenv
from application.celery_init import celery
from flask import Flask, request, redirect
from application.core.settings import settings
from application.api.user.routes import user
from flask import Flask, redirect, request
from application.api.answer.routes import answer
from application.api.internal.routes import internal
from application.api.user.routes import user
from application.celery_init import celery
from application.core.logging_config import setup_logging
from application.core.settings import settings
from application.extensions import api
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
@@ -23,16 +27,19 @@ app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
CELERY_RESULT_BACKEND=settings.CELERY_RESULT_BACKEND,
MONGO_URI=settings.MONGO_URI
MONGO_URI=settings.MONGO_URI,
)
celery.config_from_object("application.celeryconfig")
api.init_app(app)
@app.route("/")
def home():
if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1'):
return redirect('http://localhost:5173')
if request.remote_addr in ("0.0.0.0", "127.0.0.1", "localhost", "172.18.0.1"):
return redirect("http://localhost:5173")
else:
return 'Welcome to DocsGPT Backend!'
return "Welcome to DocsGPT Backend!"
@app.after_request
def after_request(response):
@@ -41,6 +48,6 @@ def after_request(response):
response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
return response
if __name__ == "__main__":
app.run(debug=settings.FLASK_DEBUG_MODE, port=7091)

View File

@@ -0,0 +1,7 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -1,22 +1,41 @@
import tiktoken
from flask import jsonify, make_response
_encoding = None
def get_encoding():
global _encoding
if _encoding is None:
_encoding = tiktoken.get_encoding("cl100k_base")
return _encoding
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
num_tokens = len(encoding.encode(string))
return num_tokens
def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
docs_content += doc.page_content
tokens = num_tokens_from_string(docs_content)
return tokens
return tokens
def check_required_fields(data, required_fields):
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Missing fields: {', '.join(missing_fields)}",
}
),
400,
)
return None