Merge branch 'main' of https://github.com/utin-francis-peter/DocsGPT into fix/traning-progress

This commit is contained in:
utin-francis-peter
2024-06-20 22:33:44 +01:00
3 changed files with 78 additions and 34 deletions

View File

@@ -9,13 +9,11 @@ import traceback
from pymongo import MongoClient from pymongo import MongoClient
from bson.objectid import ObjectId from bson.objectid import ObjectId
from application.core.settings import settings from application.core.settings import settings
from application.llm.llm_creator import LLMCreator from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request from application.error import bad_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
mongo = MongoClient(settings.MONGO_URI) mongo = MongoClient(settings.MONGO_URI)
@@ -75,8 +73,10 @@ def run_async_chain(chain, question, chat_history):
def get_data_from_api_key(api_key): def get_data_from_api_key(api_key):
data = api_key_collection.find_one({"key": api_key}) data = api_key_collection.find_one({"key": api_key})
# # Raise custom exception if the API key is not found
if data is None: if data is None:
return bad_request(401, "Invalid API key") raise Exception("Invalid API Key, please generate new key", 401)
return data return data
@@ -128,10 +128,10 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"content": "Summarise following conversation in no more than 3 " "content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same " "words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " "language as the system \n\nUser: "
+ question +question
+ "\n\n" +"\n\n"
+ "AI: " +"AI: "
+ response, +response,
}, },
{ {
"role": "user", "role": "user",
@@ -173,33 +173,39 @@ def get_prompt(prompt_id):
def complete_stream(question, retriever, conversation_id, user_api_key): def complete_stream(question, retriever, conversation_id, user_api_key):
response_full = "" try:
source_log_docs = [] response_full = ""
answer = retriever.gen() source_log_docs = []
for line in answer: answer = retriever.gen()
if "answer" in line: for line in answer:
response_full += str(line["answer"]) if "answer" in line:
data = json.dumps(line) response_full += str(line["answer"])
yield f"data: {data}\n\n" data = json.dumps(line)
elif "source" in line: yield f"data: {data}\n\n"
source_log_docs.append(line["source"]) elif "source" in line:
source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.",
"error_exception": str(e)})
yield f"data: {data}\n\n"
return
@answer.route("/stream", methods=["POST"]) @answer.route("/stream", methods=["POST"])
def stream(): def stream():
try:
data = request.get_json() data = request.get_json()
# get parameter from url question # get parameter from url question
question = data["question"] question = data["question"]
@@ -273,7 +279,29 @@ def stream():
), ),
mimetype="text/event-stream", mimetype="text/event-stream",
) )
except ValueError:
message = "Malformed request body"
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
print("err",str(e))
message = e.args[0]
status_code = 400
# # Custom exceptions with two arguments, index 1 as status code
if(len(e.args) >= 2):
status_code = e.args[1]
return Response(
error_stream_generate(message),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
data = json.dumps({"type": "error", "error":err_response})
yield f"data: {data}\n\n"
@answer.route("/api/answer", methods=["POST"]) @answer.route("/api/answer", methods=["POST"])
def api_answer(): def api_answer():

View File

@@ -257,8 +257,8 @@ def combined_json():
} }
] ]
# structure: name, language, version, description, fullName, date, docLink # structure: name, language, version, description, fullName, date, docLink
# append data from vectors_collection # append data from vectors_collection in sorted order in descending order of date
for index in vectors_collection.find({"user": user}): for index in vectors_collection.find({"user": user}).sort("date", -1):
data.append( data.append(
{ {
"name": index["name"], "name": index["name"],

View File

@@ -68,6 +68,15 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
query: { conversationId: data.id }, query: { conversationId: data.id },
}), }),
); );
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(conversationSlice.actions.setStatus('failed'));
dispatch(
conversationSlice.actions.raiseError({
index: state.conversation.queries.length - 1,
message: data.error,
}),
);
} else { } else {
const result = data.answer; const result = data.answer;
dispatch( dispatch(
@@ -191,6 +200,13 @@ export const conversationSlice = createSlice({
setStatus(state, action: PayloadAction<Status>) { setStatus(state, action: PayloadAction<Status>) {
state.status = action.payload; state.status = action.payload;
}, },
raiseError(
state,
action: PayloadAction<{ index: number; message: string }>,
) {
const { index, message } = action.payload;
state.queries[index].error = message;
},
}, },
extraReducers(builder) { extraReducers(builder) {
builder builder
@@ -204,7 +220,7 @@ export const conversationSlice = createSlice({
} }
state.status = 'failed'; state.status = 'failed';
state.queries[state.queries.length - 1].error = state.queries[state.queries.length - 1].error =
'Something went wrong. Please try again later.'; 'Something went wrong. Please check your internet connection.';
}); });
}, },
}); });