mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Made the requested changes
This commit is contained in:
@@ -118,8 +118,19 @@ def is_azure_configured():
|
||||
)
|
||||
|
||||
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||
if conversation_id is not None and conversation_id != "None":
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
|
||||
if conversation_id is not None and index is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
}
|
||||
}
|
||||
)
|
||||
elif conversation_id is not None and conversation_id != "None":
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
@@ -186,7 +197,7 @@ def get_prompt(prompt_id):
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question, retriever, conversation_id, user_api_key, isNoneDoc=False
|
||||
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
|
||||
):
|
||||
|
||||
try:
|
||||
@@ -217,7 +228,7 @@ def complete_stream(
|
||||
)
|
||||
if user_api_key is None:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
conversation_id, question, response_full, source_log_docs, llm,index
|
||||
)
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
@@ -282,6 +293,9 @@ class Stream(Resource):
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index":fields.Integer(
|
||||
required=False, description="The position where query is to be updated"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -290,19 +304,20 @@ class Stream(Resource):
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
|
||||
if "index" in data:
|
||||
required_fields = ["question","conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = json.loads(history)
|
||||
history = str(data.get("history", []))
|
||||
history = str(json.loads(history))
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
|
||||
|
||||
index=data.get("index",None);
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
@@ -351,6 +366,7 @@ class Stream(Resource):
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -168,48 +168,6 @@ class UpdateConversationName(Resource):
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@user_ns.route("/api/update_conversation_queries")
|
||||
class UpdateConversationQueries(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateConversationQueriesModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Conversation ID"),
|
||||
"limit": fields.Integer(
|
||||
required=True, description="Number by which queries should be sliced."
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates the queries in a conversation",
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "limit"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"])},[{
|
||||
"$set": {
|
||||
"queries": {
|
||||
"$slice": ["$queries", data["limit"]]
|
||||
}
|
||||
}
|
||||
}])
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(data["id"])}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
except Exception as err:
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
return make_response(jsonify(conversation["queries"]), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/feedback")
|
||||
class SubmitFeedback(Resource):
|
||||
|
||||
@@ -17,8 +17,7 @@ const endpoints = {
|
||||
TOKEN_ANALYTICS: '/api/get_token_analytics',
|
||||
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
|
||||
LOGS: `/api/get_user_logs`,
|
||||
MANAGE_SYNC: '/api/manage_sync',
|
||||
UPDATE_CONVERSATION_QUERIES: '/api/update_conversation_queries',
|
||||
MANAGE_SYNC: '/api/manage_sync'
|
||||
},
|
||||
CONVERSATION: {
|
||||
ANSWER: '/api/answer',
|
||||
|
||||
@@ -27,8 +27,6 @@ const conversationService = {
|
||||
apiClient.get(endpoints.CONVERSATION.DELETE_ALL),
|
||||
update: (data: any): Promise<any> =>
|
||||
apiClient.post(endpoints.CONVERSATION.UPDATE, data),
|
||||
update_conversation_queries: (data: any): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.UPDATE_CONVERSATION_QUERIES, data),
|
||||
};
|
||||
|
||||
export default conversationService;
|
||||
|
||||
@@ -22,6 +22,7 @@ import { FEEDBACK, Query } from './conversationModels';
|
||||
import {
|
||||
addQuery,
|
||||
fetchAnswer,
|
||||
resendQuery,
|
||||
selectQueries,
|
||||
selectStatus,
|
||||
setConversation,
|
||||
@@ -87,25 +88,17 @@ export default function Conversation() {
|
||||
question,
|
||||
isRetry = false,
|
||||
updated = null,
|
||||
indx = null,
|
||||
indx = undefined,
|
||||
}: {
|
||||
question: string;
|
||||
isRetry?: boolean;
|
||||
updated?: boolean | null;
|
||||
indx?: number | null;
|
||||
indx?: number;
|
||||
}) => {
|
||||
if (updated === true) {
|
||||
conversationService
|
||||
.update_conversation_queries({
|
||||
id: conversationId,
|
||||
limit: indx,
|
||||
})
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
dispatch(setConversation(data));
|
||||
!isRetry && dispatch(addQuery({ prompt: question })); //dispatch only new queries
|
||||
fetchStream.current = dispatch(fetchAnswer({ question }));
|
||||
});
|
||||
!isRetry &&
|
||||
dispatch(resendQuery({ index: indx as number, prompt: question })); //dispatch only new queries
|
||||
fetchStream.current = dispatch(fetchAnswer({ question, indx }));
|
||||
} else {
|
||||
question = question.trim();
|
||||
if (question === '') return;
|
||||
|
||||
@@ -75,6 +75,7 @@ export function handleFetchAnswerSteaming(
|
||||
chunks: string,
|
||||
token_limit: number,
|
||||
onEvent: (event: MessageEvent) => void,
|
||||
indx?: number,
|
||||
): Promise<Answer> {
|
||||
history = history.map((item) => {
|
||||
return { prompt: item.prompt, response: item.response };
|
||||
@@ -87,6 +88,7 @@ export function handleFetchAnswerSteaming(
|
||||
chunks: chunks,
|
||||
token_limit: token_limit,
|
||||
isNoneDoc: selectedDocs === null,
|
||||
index: indx,
|
||||
};
|
||||
if (selectedDocs && 'id' in selectedDocs) {
|
||||
payload.active_docs = selectedDocs.id as string;
|
||||
|
||||
@@ -41,4 +41,5 @@ export interface RetrievalPayload {
|
||||
chunks: string;
|
||||
token_limit: number;
|
||||
isNoneDoc: boolean;
|
||||
index?: number;
|
||||
}
|
||||
|
||||
@@ -25,9 +25,10 @@ export function handleAbort() {
|
||||
}
|
||||
}
|
||||
|
||||
export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
'fetchAnswer',
|
||||
async ({ question }, { dispatch, getState }) => {
|
||||
export const fetchAnswer = createAsyncThunk<
|
||||
Answer,
|
||||
{ question: string; indx?: number }
|
||||
>('fetchAnswer', async ({ question, indx }, { dispatch, getState }) => {
|
||||
if (abortController) {
|
||||
abortController.abort();
|
||||
}
|
||||
@@ -47,7 +48,6 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
state.preference.prompt.id,
|
||||
state.preference.chunks,
|
||||
state.preference.token_limit,
|
||||
|
||||
(event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
@@ -63,7 +63,7 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
if (!isSourceUpdated) {
|
||||
dispatch(
|
||||
updateStreamingSource({
|
||||
index: state.conversation.queries.length - 1,
|
||||
index: indx ?? state.conversation.queries.length - 1,
|
||||
query: { sources: [] },
|
||||
}),
|
||||
);
|
||||
@@ -78,7 +78,7 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
isSourceUpdated = true;
|
||||
dispatch(
|
||||
updateStreamingSource({
|
||||
index: state.conversation.queries.length - 1,
|
||||
index: indx ?? state.conversation.queries.length - 1,
|
||||
query: { sources: data.source ?? [] },
|
||||
}),
|
||||
);
|
||||
@@ -87,7 +87,7 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
dispatch(conversationSlice.actions.setStatus('failed'));
|
||||
dispatch(
|
||||
conversationSlice.actions.raiseError({
|
||||
index: state.conversation.queries.length - 1,
|
||||
index: indx ?? state.conversation.queries.length - 1,
|
||||
message: data.error,
|
||||
}),
|
||||
);
|
||||
@@ -95,12 +95,13 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
const result = data.answer;
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
index: state.conversation.queries.length - 1,
|
||||
index: indx ?? state.conversation.queries.length - 1,
|
||||
query: { response: result },
|
||||
}),
|
||||
);
|
||||
}
|
||||
},
|
||||
indx,
|
||||
);
|
||||
} else {
|
||||
const answer = await handleFetchAnswer(
|
||||
@@ -128,7 +129,7 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
|
||||
dispatch(
|
||||
updateQuery({
|
||||
index: state.conversation.queries.length - 1,
|
||||
index: indx ?? state.conversation.queries.length - 1,
|
||||
query: { response: answer.answer, sources: sourcesPrepped },
|
||||
}),
|
||||
);
|
||||
@@ -156,7 +157,7 @@ export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
result: '',
|
||||
sources: [],
|
||||
};
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
export const conversationSlice = createSlice({
|
||||
@@ -169,6 +170,12 @@ export const conversationSlice = createSlice({
|
||||
setConversation(state, action: PayloadAction<Query[]>) {
|
||||
state.queries = action.payload;
|
||||
},
|
||||
resendQuery(
|
||||
state,
|
||||
action: PayloadAction<{ index: number; prompt: string; query?: Query }>,
|
||||
) {
|
||||
state.queries[action.payload.index] = action.payload;
|
||||
},
|
||||
updateStreamingQuery(
|
||||
state,
|
||||
action: PayloadAction<{ index: number; query: Partial<Query> }>,
|
||||
@@ -250,6 +257,7 @@ export const selectStatus = (state: RootState) => state.conversation.status;
|
||||
export const {
|
||||
addQuery,
|
||||
updateQuery,
|
||||
resendQuery,
|
||||
updateStreamingQuery,
|
||||
updateConversationId,
|
||||
updateStreamingSource,
|
||||
|
||||
Reference in New Issue
Block a user