streaming experiments

This commit is contained in:
Alex
2023-05-18 23:52:59 +01:00
parent e49dd0cc6a
commit ff2e79fe7b
3 changed files with 70 additions and 35 deletions

View File

@@ -9,7 +9,7 @@ import dotenv
import requests import requests
from celery import Celery from celery import Celery
from celery.result import AsyncResult from celery.result import AsyncResult
from flask import Flask, request, render_template, send_from_directory, jsonify from flask import Flask, request, render_template, send_from_directory, jsonify, Response
from langchain import FAISS from langchain import FAISS
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
from langchain.chains import LLMChain, ConversationalRetrievalChain from langchain.chains import LLMChain, ConversationalRetrievalChain
@@ -120,6 +120,21 @@ def home():
embeddings_choice=settings.EMBEDDINGS_NAME) embeddings_choice=settings.EMBEDDINGS_NAME)
def complete_stream(input):
import time
for i in range(10):
data = json.dumps({"answer": i})
#data = {"answer": str(i)}
yield f"data: {data}\n\n"
time.sleep(0.05)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
@app.route("/stream", methods=['POST', 'GET'])
def stream():
return Response(complete_stream("hi"), mimetype='text/event-stream')
@app.route("/api/answer", methods=["POST"]) @app.route("/api/answer", methods=["POST"])
def api_answer(): def api_answer():
data = request.get_json() data = request.get_json()

View File

@@ -7,6 +7,7 @@ export function fetchAnswerApi(
question: string, question: string,
apiKey: string, apiKey: string,
selectedDocs: Doc, selectedDocs: Doc,
onEvent: (event: MessageEvent) => void,
): Promise<Answer> { ): Promise<Answer> {
let namePath = selectedDocs.name; let namePath = selectedDocs.name;
if (selectedDocs.language === namePath) { if (selectedDocs.language === namePath) {
@@ -28,30 +29,23 @@ export function fetchAnswerApi(
'/'; '/';
} }
return fetch(apiHost + '/api/answer', { return new Promise<Answer>((resolve, reject) => {
method: 'POST', const url = new URL(apiHost + '/stream');
headers: { url.searchParams.append('question', question);
'Content-Type': 'application/json', url.searchParams.append('api_key', apiKey);
}, url.searchParams.append('embeddings_key', apiKey);
body: JSON.stringify({ url.searchParams.append('history', localStorage.getItem('chatHistory'));
question: question, url.searchParams.append('active_docs', docPath);
api_key: apiKey,
embeddings_key: apiKey, const eventSource = new EventSource(url.href);
history: localStorage.getItem('chatHistory'),
active_docs: docPath, eventSource.onmessage = onEvent;
}),
}) eventSource.onerror = (error) => {
.then((response) => { console.log('Connection failed.');
if (response.ok) { eventSource.close();
return response.json(); };
} else { });
Promise.reject(response);
}
})
.then((data) => {
const result = data.answer;
return { answer: result, query: question, result };
});
} }
export function sendFeedback( export function sendFeedback(

View File

@@ -1,7 +1,8 @@
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
import store from '../store'; import store from '../store';
import { fetchAnswerApi } from './conversationApi'; import { fetchAnswerApi } from './conversationApi';
import { Answer, ConversationState, Query } from './conversationModels'; import { ConversationState, Query } from './conversationModels';
import { Dispatch } from 'react';
const initialState: ConversationState = { const initialState: ConversationState = {
queries: [], queries: [],
@@ -9,18 +10,35 @@ const initialState: ConversationState = {
}; };
export const fetchAnswer = createAsyncThunk< export const fetchAnswer = createAsyncThunk<
Answer, void,
{ question: string }, { question: string },
{ state: RootState } { dispatch: Dispatch; state: RootState }
>('fetchAnswer', async ({ question }, { getState }) => { >('fetchAnswer', ({ question }, { dispatch, getState }) => {
const state = getState(); const state = getState();
const answer = await fetchAnswerApi( fetchAnswerApi(
question, question,
state.preference.apiKey, state.preference.apiKey,
state.preference.selectedDocs!, state.preference.selectedDocs!,
(event) => {
const data = JSON.parse(event.data);
console.log(data);
// check if the 'end' event has been received
if (data.type === 'end') {
// set status to 'idle'
dispatch(conversationSlice.actions.setStatus('idle'));
} else {
const result = JSON.stringify(data.answer);
dispatch(
updateQuery({
index: state.conversation.queries.length - 1,
query: { response: result },
}),
);
}
},
); );
return answer;
}); });
export const conversationSlice = createSlice({ export const conversationSlice = createSlice({
@@ -35,10 +53,18 @@ export const conversationSlice = createSlice({
action: PayloadAction<{ index: number; query: Partial<Query> }>, action: PayloadAction<{ index: number; query: Partial<Query> }>,
) { ) {
const index = action.payload.index; const index = action.payload.index;
state.queries[index] = { if (action.payload.query.response) {
...state.queries[index], state.queries[index].response =
...action.payload.query, (state.queries[index].response || '') + action.payload.query.response;
}; } else {
state.queries[index] = {
...state.queries[index],
...action.payload.query,
};
}
},
setStatus(state, action: PayloadAction<string>) {
state.status = action.payload;
}, },
}, },
extraReducers(builder) { extraReducers(builder) {