mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #1154 from siiddhantt/feat/openapi-spec
feature: OpenAPI specification compliant endpoints
This commit is contained in:
@@ -1,20 +1,24 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from flask import Blueprint, request, Response, current_app
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from pymongo import MongoClient
|
||||
from bson.objectid import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, current_app, make_response, request, Response
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.error import bad_request
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +29,10 @@ sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
api_key_collection = db["api_keys"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
gpt_model = ""
|
||||
# to have some kind of default behaviour
|
||||
@@ -186,10 +193,10 @@ def complete_stream(
|
||||
answer = retriever.gen()
|
||||
sources = retriever.search()
|
||||
for source in sources:
|
||||
if("text" in source):
|
||||
source["text"] = source["text"][:100].strip()+"..."
|
||||
if(len(sources) > 0):
|
||||
data = json.dumps({"type":"source","source":sources})
|
||||
if "text" in source:
|
||||
source["text"] = source["text"][:100].strip() + "..."
|
||||
if len(sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": sources})
|
||||
yield f"data: {data}\n\n"
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
@@ -243,109 +250,133 @@ def complete_stream(
|
||||
return
|
||||
|
||||
|
||||
@answer.route("/stream", methods=["POST"])
|
||||
def stream():
|
||||
try:
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
else:
|
||||
history = data["history"]
|
||||
history = json.loads(history)
|
||||
if "conversation_id" not in data:
|
||||
conversation_id = None
|
||||
else:
|
||||
conversation_id = data["conversation_id"]
|
||||
if "prompt_id" in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = "default"
|
||||
if "selectedDocs" in data and data["selectedDocs"] is None:
|
||||
chunks = 0
|
||||
elif "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
## retriever can be "brave_search, duckduck_search or classic"
|
||||
retriever_name = data["retriever"] if "retriever" in data else "classic"
|
||||
|
||||
# check if active_docs or api_key is set
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
retriever_name = data_key["retriever"] or retriever_name
|
||||
user_api_key = data["api_key"]
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
current_app.logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
@answer_ns.route("/stream")
|
||||
class Stream(Resource):
|
||||
stream_model = api.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Chat history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"selectedDocs": fields.String(
|
||||
required=False, description="Selected documents"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(message), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
message = e.args[0]
|
||||
status_code = 400
|
||||
# # Custom exceptions with two arguments, index 1 as status code
|
||||
if len(e.args) >= 2:
|
||||
status_code = e.args[1]
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = json.loads(history)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
if "selectedDocs" in data and data["selectedDocs"] is None:
|
||||
chunks = 0
|
||||
else:
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
current_app.logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(message), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
message = e.args[0]
|
||||
status_code = 400
|
||||
# Custom exceptions with two arguments, index 1 as status code
|
||||
if len(e.args) >= 2:
|
||||
status_code = e.args[1]
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def error_stream_generate(err_response):
|
||||
@@ -353,180 +384,235 @@ def error_stream_generate(err_response):
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
@answer.route("/api/answer", methods=["POST"])
|
||||
def api_answer():
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
else:
|
||||
history = data["history"]
|
||||
if "conversation_id" not in data:
|
||||
conversation_id = None
|
||||
else:
|
||||
conversation_id = data["conversation_id"]
|
||||
print("-" * 5)
|
||||
if "prompt_id" in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = "default"
|
||||
if "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
## retriever can be brave_search, duckduck_search or classic
|
||||
retriever_name = data["retriever"] if "retriever" in data else "classic"
|
||||
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
# check if the vectorstore is set
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
retriever_name = data_key["retriever"] or retriever_name
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
source_log_docs = []
|
||||
response_full = ""
|
||||
for line in retriever.gen():
|
||||
if "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
)
|
||||
)
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
|
||||
@answer.route("/api/search", methods=["POST"])
|
||||
def api_search():
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
if "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
source = {"active_docs":data_key["source"]}
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
if "retriever" in data:
|
||||
retriever_name = data["retriever"]
|
||||
else:
|
||||
retriever_name = "classic"
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
docs = retriever.search()
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
@answer_ns.route("/api/answer")
|
||||
class Answer(Resource):
|
||||
answer_model = api.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
"question": fields.String(
|
||||
required=True, description="The question to answer"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Conversation history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide an answer based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
return docs
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
source_log_docs = []
|
||||
response_full = ""
|
||||
for line in retriever.gen():
|
||||
if "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
)
|
||||
)
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(result, 200)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class Search(Resource):
|
||||
search_model = api.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to search"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key for authentication"
|
||||
),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents for retrieval"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"token_limit": fields.Integer(
|
||||
required=False, description="Limit for tokens"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(search_model)
|
||||
@api.doc(
|
||||
description="Search for relevant documents based on the question and retriever"
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
docs = retriever.search()
|
||||
retriever_params = retriever.get_params()
|
||||
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(docs, 200)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,19 @@
|
||||
import platform
|
||||
|
||||
import dotenv
|
||||
from application.celery_init import celery
|
||||
from flask import Flask, request, redirect
|
||||
from application.core.settings import settings
|
||||
from application.api.user.routes import user
|
||||
from flask import Flask, redirect, request
|
||||
|
||||
from application.api.answer.routes import answer
|
||||
from application.api.internal.routes import internal
|
||||
from application.api.user.routes import user
|
||||
from application.celery_init import celery
|
||||
from application.core.logging_config import setup_logging
|
||||
from application.core.settings import settings
|
||||
from application.extensions import api
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import pathlib
|
||||
|
||||
pathlib.PosixPath = pathlib.WindowsPath
|
||||
|
||||
dotenv.load_dotenv()
|
||||
@@ -23,16 +27,19 @@ app.config.update(
|
||||
UPLOAD_FOLDER="inputs",
|
||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||
CELERY_RESULT_BACKEND=settings.CELERY_RESULT_BACKEND,
|
||||
MONGO_URI=settings.MONGO_URI
|
||||
MONGO_URI=settings.MONGO_URI,
|
||||
)
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
api.init_app(app)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1'):
|
||||
return redirect('http://localhost:5173')
|
||||
if request.remote_addr in ("0.0.0.0", "127.0.0.1", "localhost", "172.18.0.1"):
|
||||
return redirect("http://localhost:5173")
|
||||
else:
|
||||
return 'Welcome to DocsGPT Backend!'
|
||||
return "Welcome to DocsGPT Backend!"
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
@@ -41,6 +48,6 @@ def after_request(response):
|
||||
response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS")
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=settings.FLASK_DEBUG_MODE, port=7091)
|
||||
|
||||
|
||||
7
application/extensions.py
Normal file
7
application/extensions.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from flask_restx import Api
|
||||
|
||||
api = Api(
|
||||
version="1.0",
|
||||
title="DocsGPT API",
|
||||
description="API for DocsGPT",
|
||||
)
|
||||
@@ -13,6 +13,7 @@ esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
Flask==3.0.3
|
||||
faiss-cpu==1.8.0.post1
|
||||
flask-restx==1.3.0
|
||||
gunicorn==23.0.0
|
||||
html2text==2024.2.26
|
||||
javalang==0.13.0
|
||||
|
||||
@@ -1,22 +1,41 @@
|
||||
import tiktoken
|
||||
from flask import jsonify, make_response
|
||||
|
||||
_encoding = None
|
||||
|
||||
|
||||
def get_encoding():
|
||||
global _encoding
|
||||
if _encoding is None:
|
||||
_encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return _encoding
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
encoding = get_encoding()
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
|
||||
|
||||
def count_tokens_docs(docs):
|
||||
docs_content = ""
|
||||
for doc in docs:
|
||||
docs_content += doc.page_content
|
||||
|
||||
tokens = num_tokens_from_string(docs_content)
|
||||
return tokens
|
||||
return tokens
|
||||
|
||||
|
||||
def check_required_fields(data, required_fields):
|
||||
missing_fields = [field for field in required_fields if field not in data]
|
||||
if missing_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Missing fields: {', '.join(missing_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -82,12 +82,12 @@ function Dropdown({
|
||||
}`}
|
||||
>
|
||||
{typeof selectedValue === 'string' ? (
|
||||
<span className="overflow-hidden text-ellipsis dark:text-bright-gray">
|
||||
<span className="truncate dark:text-bright-gray">
|
||||
{selectedValue}
|
||||
</span>
|
||||
) : (
|
||||
<span
|
||||
className={`truncate overflow-hidden dark:text-bright-gray ${
|
||||
className={`truncate dark:text-bright-gray ${
|
||||
!selectedValue && 'text-silver dark:text-gray-400'
|
||||
} ${contentSize}`}
|
||||
>
|
||||
|
||||
@@ -97,14 +97,13 @@ export default function CreateAPIKeyModal({
|
||||
<div className="my-4">
|
||||
<Dropdown
|
||||
placeholder={t('modals.createAPIKey.sourceDoc')}
|
||||
selectedValue={sourcePath}
|
||||
selectedValue={sourcePath ? sourcePath.name : null}
|
||||
onSelect={(selection: {
|
||||
name: string;
|
||||
id: string;
|
||||
type: string;
|
||||
}) => {
|
||||
setSourcePath(selection);
|
||||
console.log(selection);
|
||||
}}
|
||||
options={extractDocPaths()}
|
||||
size="w-full"
|
||||
|
||||
@@ -181,8 +181,8 @@ export default function Analytics() {
|
||||
border="border"
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-8 w-full flex flex-col sm:flex-row gap-3">
|
||||
<div className="h-[345px] sm:w-1/2 w-full px-6 py-5 border rounded-2xl border-silver dark:border-silver/40">
|
||||
<div className="mt-8 w-full flex flex-col [@media(min-width:1080px)]:flex-row gap-3">
|
||||
<div className="h-[345px] [@media(min-width:1080px)]:w-1/2 w-full px-6 py-5 border rounded-2xl border-silver dark:border-silver/40">
|
||||
<div className="flex flex-row items-center justify-start gap-3">
|
||||
<p className="font-bold text-jet dark:text-bright-gray">
|
||||
Messages
|
||||
@@ -227,7 +227,7 @@ export default function Analytics() {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="h-[345px] sm:w-1/2 w-full px-6 py-5 border rounded-2xl border-silver dark:border-silver/40">
|
||||
<div className="h-[345px] [@media(min-width:1080px)]:w-1/2 w-full px-6 py-5 border rounded-2xl border-silver dark:border-silver/40">
|
||||
<div className="flex flex-row items-center justify-start gap-3">
|
||||
<p className="font-bold text-jet dark:text-bright-gray">
|
||||
Token Usage
|
||||
|
||||
Reference in New Issue
Block a user