mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'main' of https://github.com/utin-francis-peter/DocsGPT into fix/traning-progress
This commit is contained in:
@@ -9,13 +9,11 @@ import traceback
|
|||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.retriever.retriever_creator import RetrieverCreator
|
from application.retriever.retriever_creator import RetrieverCreator
|
||||||
from application.error import bad_request
|
from application.error import bad_request
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
mongo = MongoClient(settings.MONGO_URI)
|
mongo = MongoClient(settings.MONGO_URI)
|
||||||
@@ -75,8 +73,10 @@ def run_async_chain(chain, question, chat_history):
|
|||||||
|
|
||||||
def get_data_from_api_key(api_key):
|
def get_data_from_api_key(api_key):
|
||||||
data = api_key_collection.find_one({"key": api_key})
|
data = api_key_collection.find_one({"key": api_key})
|
||||||
|
|
||||||
|
# # Raise custom exception if the API key is not found
|
||||||
if data is None:
|
if data is None:
|
||||||
return bad_request(401, "Invalid API key")
|
raise Exception("Invalid API Key, please generate new key", 401)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -128,10 +128,10 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
|
|||||||
"content": "Summarise following conversation in no more than 3 "
|
"content": "Summarise following conversation in no more than 3 "
|
||||||
"words, respond ONLY with the summary, use the same "
|
"words, respond ONLY with the summary, use the same "
|
||||||
"language as the system \n\nUser: "
|
"language as the system \n\nUser: "
|
||||||
+ question
|
+question
|
||||||
+ "\n\n"
|
+"\n\n"
|
||||||
+ "AI: "
|
+"AI: "
|
||||||
+ response,
|
+response,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -173,33 +173,39 @@ def get_prompt(prompt_id):
|
|||||||
|
|
||||||
def complete_stream(question, retriever, conversation_id, user_api_key):
|
def complete_stream(question, retriever, conversation_id, user_api_key):
|
||||||
|
|
||||||
response_full = ""
|
try:
|
||||||
source_log_docs = []
|
response_full = ""
|
||||||
answer = retriever.gen()
|
source_log_docs = []
|
||||||
for line in answer:
|
answer = retriever.gen()
|
||||||
if "answer" in line:
|
for line in answer:
|
||||||
response_full += str(line["answer"])
|
if "answer" in line:
|
||||||
data = json.dumps(line)
|
response_full += str(line["answer"])
|
||||||
yield f"data: {data}\n\n"
|
data = json.dumps(line)
|
||||||
elif "source" in line:
|
yield f"data: {data}\n\n"
|
||||||
source_log_docs.append(line["source"])
|
elif "source" in line:
|
||||||
|
source_log_docs.append(line["source"])
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
|
||||||
)
|
|
||||||
conversation_id = save_conversation(
|
|
||||||
conversation_id, question, response_full, source_log_docs, llm
|
|
||||||
)
|
|
||||||
|
|
||||||
# send data.type = "end" to indicate that the stream has ended as json
|
|
||||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
data = json.dumps({"type": "end"})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
|
llm = LLMCreator.create_llm(
|
||||||
|
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||||
|
)
|
||||||
|
conversation_id = save_conversation(
|
||||||
|
conversation_id, question, response_full, source_log_docs, llm
|
||||||
|
)
|
||||||
|
|
||||||
|
# send data.type = "end" to indicate that the stream has ended as json
|
||||||
|
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
data = json.dumps({"type": "end"})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.",
|
||||||
|
"error_exception": str(e)})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
@answer.route("/stream", methods=["POST"])
|
@answer.route("/stream", methods=["POST"])
|
||||||
def stream():
|
def stream():
|
||||||
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
# get parameter from url question
|
# get parameter from url question
|
||||||
question = data["question"]
|
question = data["question"]
|
||||||
@@ -273,7 +279,29 @@ def stream():
|
|||||||
),
|
),
|
||||||
mimetype="text/event-stream",
|
mimetype="text/event-stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
message = "Malformed request body"
|
||||||
|
return Response(
|
||||||
|
error_stream_generate(message),
|
||||||
|
status=400,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("err",str(e))
|
||||||
|
message = e.args[0]
|
||||||
|
status_code = 400
|
||||||
|
# # Custom exceptions with two arguments, index 1 as status code
|
||||||
|
if(len(e.args) >= 2):
|
||||||
|
status_code = e.args[1]
|
||||||
|
return Response(
|
||||||
|
error_stream_generate(message),
|
||||||
|
status=status_code,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
def error_stream_generate(err_response):
|
||||||
|
data = json.dumps({"type": "error", "error":err_response})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
@answer.route("/api/answer", methods=["POST"])
|
@answer.route("/api/answer", methods=["POST"])
|
||||||
def api_answer():
|
def api_answer():
|
||||||
|
|||||||
@@ -257,8 +257,8 @@ def combined_json():
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
# structure: name, language, version, description, fullName, date, docLink
|
# structure: name, language, version, description, fullName, date, docLink
|
||||||
# append data from vectors_collection
|
# append data from vectors_collection in sorted order in descending order of date
|
||||||
for index in vectors_collection.find({"user": user}):
|
for index in vectors_collection.find({"user": user}).sort("date", -1):
|
||||||
data.append(
|
data.append(
|
||||||
{
|
{
|
||||||
"name": index["name"],
|
"name": index["name"],
|
||||||
|
|||||||
@@ -68,6 +68,15 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
|||||||
query: { conversationId: data.id },
|
query: { conversationId: data.id },
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
} else if (data.type === 'error') {
|
||||||
|
// set status to 'failed'
|
||||||
|
dispatch(conversationSlice.actions.setStatus('failed'));
|
||||||
|
dispatch(
|
||||||
|
conversationSlice.actions.raiseError({
|
||||||
|
index: state.conversation.queries.length - 1,
|
||||||
|
message: data.error,
|
||||||
|
}),
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
const result = data.answer;
|
const result = data.answer;
|
||||||
dispatch(
|
dispatch(
|
||||||
@@ -191,6 +200,13 @@ export const conversationSlice = createSlice({
|
|||||||
setStatus(state, action: PayloadAction<Status>) {
|
setStatus(state, action: PayloadAction<Status>) {
|
||||||
state.status = action.payload;
|
state.status = action.payload;
|
||||||
},
|
},
|
||||||
|
raiseError(
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ index: number; message: string }>,
|
||||||
|
) {
|
||||||
|
const { index, message } = action.payload;
|
||||||
|
state.queries[index].error = message;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers(builder) {
|
||||||
builder
|
builder
|
||||||
@@ -204,7 +220,7 @@ export const conversationSlice = createSlice({
|
|||||||
}
|
}
|
||||||
state.status = 'failed';
|
state.status = 'failed';
|
||||||
state.queries[state.queries.length - 1].error =
|
state.queries[state.queries.length - 1].error =
|
||||||
'Something went wrong. Please try again later.';
|
'Something went wrong. Please check your internet connection.';
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user