diff --git a/application/vectorstore/milvus.py b/application/vectorstore/milvus.py index 9871991e..d4380666 100644 --- a/application/vectorstore/milvus.py +++ b/application/vectorstore/milvus.py @@ -7,7 +7,7 @@ from application.vectorstore.base import BaseVectorStore class MilvusStore(BaseVectorStore): - def __init__(self, path: str = "", embeddings_key: str = "embeddings"): + def __init__(self, source_id: str = "", embeddings_key: str = "embeddings"): super().__init__() from langchain_milvus import Milvus @@ -20,10 +20,11 @@ class MilvusStore(BaseVectorStore): collection_name=settings.MILVUS_COLLECTION_NAME, connection_args=connection_args, ) - self._path = path + self._source_id = source_id def search(self, question, k=2, *args, **kwargs): - return self._docsearch.similarity_search(query=question, k=k, filter={"path": self._path} *args, **kwargs) + expr = f"source_id == '{self._source_id}'" + return self._docsearch.similarity_search(query=question, k=k, expr=expr, *args, **kwargs) def add_texts(self, texts: List[str], metadatas: Optional[List[dict]], *args, **kwargs): ids = [str(uuid4()) for _ in range(len(texts))] diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 1b7e9d41..3ca05c7b 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -20,6 +20,7 @@ const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true'; export const fetchAnswer = createAsyncThunk( 'fetchAnswer', async ({ question }, { dispatch, getState, signal }) => { + let isSourceUpdated = false; const state = getState() as RootState; if (state.preference) { if (API_STREAMING) { @@ -36,9 +37,7 @@ export const fetchAnswer = createAsyncThunk( (event) => { const data = JSON.parse(event.data); - // check if the 'end' event has been received if (data.type === 'end') { - // set status to 'idle' dispatch(conversationSlice.actions.setStatus('idle')); getConversations() .then((fetchedConversations) => { @@ -47,6 +46,14 @@ export const fetchAnswer = createAsyncThunk( .catch((error) => { console.error('Failed to fetch conversations: ', error); }); + if (!isSourceUpdated) { + dispatch( + updateStreamingSource({ + index: state.conversation.queries.length - 1, + query: { sources: [] }, + }), + ); + } } else if (data.type === 'id') { dispatch( updateConversationId({ @@ -54,6 +61,7 @@ export const fetchAnswer = createAsyncThunk( }), ); } else if (data.type === 'source') { + isSourceUpdated = true; dispatch( updateStreamingSource({ index: state.conversation.queries.length - 1,