From 610adcbefc7671d56b3367e6074215480c00ea73 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 11 Jun 2023 22:56:34 +0100 Subject: [PATCH] Sources in responses --- application/app.py | 7 + frontend/src/conversation/Conversation.tsx | 2 + .../src/conversation/ConversationBubble.tsx | 183 +++++++++++------- frontend/src/conversation/conversationApi.ts | 2 +- .../src/conversation/conversationModels.ts | 2 + .../src/conversation/conversationSlice.ts | 29 ++- 6 files changed, 147 insertions(+), 78 deletions(-) diff --git a/application/app.py b/application/app.py index 6e19f03b..9c75987d 100644 --- a/application/app.py +++ b/application/app.py @@ -157,6 +157,10 @@ def complete_stream(question, docsearch, chat_history, api_key): docs_together = "\n".join([doc.page_content for doc in docs]) p_chat_combine = chat_combine_template.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + data = json.dumps({"type": "source", "doc": doc.page_content}) + yield f"data:{data}\n\n" + if len(chat_history) > 1: tokens_current_history = 0 # count tokens in history @@ -308,6 +312,9 @@ def api_answer(): except Exception: pass + sources = docsearch.similarity_search(question, k=2) + result['sources'] = [{'title': i.page_content, 'text': i.page_content} for i in sources] + # mock result # result = { # "answer": "The answer is 42", diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx index 471a5cd6..7cf59ff6 100644 --- a/frontend/src/conversation/Conversation.tsx +++ b/frontend/src/conversation/Conversation.tsx @@ -60,6 +60,7 @@ export default function Conversation() { key={`${index}ANSWER`} message={query.response} type={'ANSWER'} + sources={query.sources} feedback={query.feedback} handleFeedback={(feedback: FEEDBACK) => handleFeedback(query, feedback, index) @@ -83,6 +84,7 @@ export default function Conversation() { key={`${index}QUESTION`} message={query.prompt} type="QUESTION" + sources={query.sources} > {prepResponseView(query, index)} diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index ae3db621..add7bf1a 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -8,6 +8,8 @@ import ReactMarkdown from 'react-markdown'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism'; +const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false; + const ConversationBubble = forwardRef< HTMLDivElement, { @@ -16,12 +18,14 @@ const ConversationBubble = forwardRef< className?: string; feedback?: FEEDBACK; handleFeedback?: (feedback: FEEDBACK) => void; + sources?: { title: string; text: string }[]; } >(function ConversationBubble( - { message, type, className, feedback, handleFeedback }, + { message, type, className, feedback, handleFeedback, sources }, ref, ) { const [showFeedback, setShowFeedback] = useState(false); + const [openSource, setOpenSource] = useState(null); const List = ({ ordered, children, @@ -37,7 +41,7 @@ const ConversationBubble = forwardRef< if (type === 'QUESTION') { bubble = (
- +
{message} @@ -49,85 +53,116 @@ const ConversationBubble = forwardRef< bubble = (
setShowFeedback(true)} onMouseLeave={() => setShowFeedback(false)} > - -
- {type === 'ERROR' && ( - alert - )} - - {String(children).replace(/\n$/, '')} - - ) : ( - - {children} - - ); - }, - ul({ node, children }) { - return {children}; - }, - ol({ node, children }) { - return {children}; - }, - }} +
+ +
- {message} - -
-
- + )} + + {String(children).replace(/\n$/, '')} + + ) : ( + + {children} + + ); + }, + ul({ node, children }) { + return {children}; + }, + ol({ node, children }) { + return {children}; + }, + }} + > + {message} + +
+
handleFeedback?.('LIKE')} - > -
-
- + handleFeedback?.('LIKE')} + > +
+
handleFeedback?.('DISLIKE')} - > + > + handleFeedback?.('DISLIKE')} + > +
+
+ {DisableSourceFE + ? null + : sources?.map((source, index) => ( +
+ setOpenSource(openSource === index ? null : index) + } + > +

+ {index + 1}. {source.title} +

+
+ ))} +
+ + {sources && openSource !== null && sources[openSource] && ( +
+

Source:

+ +
+

+ {sources[openSource].text} +

+
+
+ )}
); } diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts index 9e25d3af..c1e9c822 100644 --- a/frontend/src/conversation/conversationApi.ts +++ b/frontend/src/conversation/conversationApi.ts @@ -51,7 +51,7 @@ export function fetchAnswerApi( }) .then((data) => { const result = data.answer; - return { answer: result, query: question, result }; + return { answer: result, query: question, result, sources: data.sources }; }); } diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index 92ed976f..be8a0c6f 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -16,6 +16,7 @@ export interface Answer { answer: string; query: string; result: string; + sources: { title: string; text: string }[]; } export interface Query { @@ -23,4 +24,5 @@ export interface Query { response?: string; feedback?: FEEDBACK; error?: string; + sources?: { title: string; text: string }[]; } diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 70fa1a81..8b5d29ad 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -28,6 +28,14 @@ export const fetchAnswer = createAsyncThunk( if (data.type === 'end') { // set status to 'idle' dispatch(conversationSlice.actions.setStatus('idle')); + } else if (data.type === 'source') { + const result = data.doc; + dispatch( + updateStreamingSource({ + index: state.conversation.queries.length - 1, + query: { sources: [{ title: result, text: result }] }, + }), + ); } else { const result = data.answer; dispatch( @@ -50,7 +58,7 @@ export const fetchAnswer = createAsyncThunk( dispatch( updateQuery({ index: state.conversation.queries.length - 1, - query: { response: answer.answer }, + query: { response: answer.answer, sources: answer.sources }, }), ); dispatch(conversationSlice.actions.setStatus('idle')); @@ -83,6 +91,17 @@ export const conversationSlice = createSlice({ }; } }, + updateStreamingSource( + state, + action: PayloadAction<{ index: number; query: Partial }>, + ) { + const index = action.payload.index; + if (!state.queries[index].sources) { + state.queries[index].sources = [action.payload.query.sources![0]]; + } else { + state.queries[index].sources!.push(action.payload.query.sources![0]); + } + }, updateQuery( state, action: PayloadAction<{ index: number; query: Partial }>, @@ -116,6 +135,10 @@ export const selectQueries = (state: RootState) => state.conversation.queries; export const selectStatus = (state: RootState) => state.conversation.status; -export const { addQuery, updateQuery, updateStreamingQuery } = - conversationSlice.actions; +export const { + addQuery, + updateQuery, + updateStreamingQuery, + updateStreamingSource, +} = conversationSlice.actions; export default conversationSlice.reducer;