From ff2e79fe7b65b0f79abbe84479fef496bb950c73 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 18 May 2023 23:52:59 +0100 Subject: [PATCH] streaming experiments --- application/app.py | 17 ++++++- frontend/src/conversation/conversationApi.ts | 42 ++++++++--------- .../src/conversation/conversationSlice.ts | 46 +++++++++++++++---- 3 files changed, 70 insertions(+), 35 deletions(-) diff --git a/application/app.py b/application/app.py index d68c5b93..0307201a 100644 --- a/application/app.py +++ b/application/app.py @@ -9,7 +9,7 @@ import dotenv import requests from celery import Celery from celery.result import AsyncResult -from flask import Flask, request, render_template, send_from_directory, jsonify +from flask import Flask, request, render_template, send_from_directory, jsonify, Response from langchain import FAISS from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI from langchain.chains import LLMChain, ConversationalRetrievalChain @@ -120,6 +120,21 @@ def home(): embeddings_choice=settings.EMBEDDINGS_NAME) +def complete_stream(input): + import time + for i in range(10): + data = json.dumps({"answer": i}) + #data = {"answer": str(i)} + yield f"data: {data}\n\n" + time.sleep(0.05) + # send data.type = "end" to indicate that the stream has ended as json + data = json.dumps({"type": "end"}) + yield f"data: {data}\n\n" +@app.route("/stream", methods=['POST', 'GET']) +def stream(): + return Response(complete_stream("hi"), mimetype='text/event-stream') + + @app.route("/api/answer", methods=["POST"]) def api_answer(): data = request.get_json() diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts index c7320342..6c3ff03e 100644 --- a/frontend/src/conversation/conversationApi.ts +++ b/frontend/src/conversation/conversationApi.ts @@ -7,6 +7,7 @@ export function fetchAnswerApi( question: string, apiKey: string, selectedDocs: Doc, + onEvent: (event: MessageEvent) => void, ): Promise { let namePath = selectedDocs.name; if (selectedDocs.language === namePath) { @@ -28,30 +29,23 @@ export function fetchAnswerApi( '/'; } - return fetch(apiHost + '/api/answer', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - question: question, - api_key: apiKey, - embeddings_key: apiKey, - history: localStorage.getItem('chatHistory'), - active_docs: docPath, - }), - }) - .then((response) => { - if (response.ok) { - return response.json(); - } else { - Promise.reject(response); - } - }) - .then((data) => { - const result = data.answer; - return { answer: result, query: question, result }; - }); + return new Promise((resolve, reject) => { + const url = new URL(apiHost + '/stream'); + url.searchParams.append('question', question); + url.searchParams.append('api_key', apiKey); + url.searchParams.append('embeddings_key', apiKey); + url.searchParams.append('history', localStorage.getItem('chatHistory')); + url.searchParams.append('active_docs', docPath); + + const eventSource = new EventSource(url.href); + + eventSource.onmessage = onEvent; + + eventSource.onerror = (error) => { + console.log('Connection failed.'); + eventSource.close(); + }; + }); } export function sendFeedback( diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index c728b9e0..04dd0be6 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -1,7 +1,8 @@ import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit'; import store from '../store'; import { fetchAnswerApi } from './conversationApi'; -import { Answer, ConversationState, Query } from './conversationModels'; +import { ConversationState, Query } from './conversationModels'; +import { Dispatch } from 'react'; const initialState: ConversationState = { queries: [], @@ -9,18 +10,35 @@ const initialState: ConversationState = { }; export const fetchAnswer = createAsyncThunk< - Answer, + void, { question: string }, - { state: RootState } ->('fetchAnswer', async ({ question }, { getState }) => { + { dispatch: Dispatch; state: RootState } +>('fetchAnswer', ({ question }, { dispatch, getState }) => { const state = getState(); - const answer = await fetchAnswerApi( + fetchAnswerApi( question, state.preference.apiKey, state.preference.selectedDocs!, + (event) => { + const data = JSON.parse(event.data); + console.log(data); + + // check if the 'end' event has been received + if (data.type === 'end') { + // set status to 'idle' + dispatch(conversationSlice.actions.setStatus('idle')); + } else { + const result = JSON.stringify(data.answer); + dispatch( + updateQuery({ + index: state.conversation.queries.length - 1, + query: { response: result }, + }), + ); + } + }, ); - return answer; }); export const conversationSlice = createSlice({ @@ -35,10 +53,18 @@ export const conversationSlice = createSlice({ action: PayloadAction<{ index: number; query: Partial }>, ) { const index = action.payload.index; - state.queries[index] = { - ...state.queries[index], - ...action.payload.query, - }; + if (action.payload.query.response) { + state.queries[index].response = + (state.queries[index].response || '') + action.payload.query.response; + } else { + state.queries[index] = { + ...state.queries[index], + ...action.payload.query, + }; + } + }, + setStatus(state, action: PayloadAction) { + state.status = action.payload; }, }, extraReducers(builder) {