mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
@@ -5,11 +5,12 @@ import json
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import openai
|
||||
import dotenv
|
||||
import requests
|
||||
from celery import Celery
|
||||
from celery.result import AsyncResult
|
||||
from flask import Flask, request, render_template, send_from_directory, jsonify
|
||||
from flask import Flask, request, render_template, send_from_directory, jsonify, Response
|
||||
from langchain import FAISS
|
||||
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
|
||||
from langchain.chains import LLMChain, ConversationalRetrievalChain
|
||||
@@ -109,7 +110,32 @@ def run_async_chain(chain, question, chat_history):
|
||||
result["answer"] = answer
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def get_vectorstore(data):
|
||||
if "active_docs" in data:
|
||||
if data["active_docs"].split("/")[0] == "local":
|
||||
if data["active_docs"].split("/")[1] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = "indexes/" + data["active_docs"]
|
||||
else:
|
||||
vectorstore = "vectors/" + data["active_docs"]
|
||||
if data['active_docs'] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = ""
|
||||
return vectorstore
|
||||
|
||||
def get_docsearch(vectorstore, embeddings_key):
|
||||
if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
|
||||
docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
|
||||
elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
|
||||
elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
|
||||
docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
|
||||
elif settings.EMBEDDINGS_NAME == "cohere_medium":
|
||||
docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
|
||||
return docsearch
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -123,6 +149,53 @@ def home():
|
||||
return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME,
|
||||
embeddings_choice=settings.EMBEDDINGS_NAME)
|
||||
|
||||
def complete_stream(question, docsearch, chat_history, api_key):
|
||||
openai.api_key = api_key
|
||||
docs = docsearch.similarity_search(question, k=2)
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
# swap {summaries} in chat_combine_template with the summaries from the docs
|
||||
p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
|
||||
completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[
|
||||
{"role": "system", "content": p_chat_combine},
|
||||
{"role": "user", "content": question},
|
||||
], stream=True, max_tokens=1000, temperature=0)
|
||||
|
||||
for line in completion:
|
||||
if 'content' in line['choices'][0]['delta']:
|
||||
# check if the delta contains content
|
||||
data = json.dumps({"answer": str(line['choices'][0]['delta']['content'])})
|
||||
yield f"data: {data}\n\n"
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
@app.route("/stream", methods=['POST', 'GET'])
|
||||
def stream():
|
||||
# get parameter from url question
|
||||
question = request.args.get('question')
|
||||
history = request.args.get('history')
|
||||
# check if active_docs is set
|
||||
|
||||
if not api_key_set:
|
||||
api_key = request.args.get("api_key")
|
||||
else:
|
||||
api_key = settings.API_KEY
|
||||
if not embeddings_key_set:
|
||||
embeddings_key = request.args.get("embeddings_key")
|
||||
else:
|
||||
embeddings_key = settings.EMBEDDINGS_KEY
|
||||
if "active_docs" in request.args:
|
||||
vectorstore = get_vectorstore({"active_docs": request.args.get("active_docs")})
|
||||
else:
|
||||
vectorstore = ""
|
||||
docsearch = get_docsearch(vectorstore, embeddings_key)
|
||||
|
||||
|
||||
#question = "Hi"
|
||||
return Response(complete_stream(question, docsearch,
|
||||
chat_history= history, api_key=api_key), mimetype='text/event-stream')
|
||||
|
||||
|
||||
@app.route("/api/answer", methods=["POST"])
|
||||
def api_answer():
|
||||
@@ -142,32 +215,11 @@ def api_answer():
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
# check if the vectorstore is set
|
||||
if "active_docs" in data:
|
||||
if data["active_docs"].split("/")[0] == "local":
|
||||
if data["active_docs"].split("/")[1] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = "indexes/" + data["active_docs"]
|
||||
else:
|
||||
vectorstore = "vectors/" + data["active_docs"]
|
||||
if data['active_docs'] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = ""
|
||||
print(vectorstore)
|
||||
# vectorstore = "outputs/inputs/"
|
||||
vectorstore = get_vectorstore(data)
|
||||
# loading the index and the store and the prompt template
|
||||
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
|
||||
if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
|
||||
docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
|
||||
elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
|
||||
elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
|
||||
docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
|
||||
elif settings.EMBEDDINGS_NAME == "cohere_medium":
|
||||
docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
|
||||
docsearch = get_docsearch(vectorstore, embeddings_key)
|
||||
|
||||
# create a prompt template
|
||||
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
|
||||
template_format="jinja2")
|
||||
if settings.LLM_NAME == "openai_chat":
|
||||
|
||||
@@ -5,6 +5,7 @@ services:
|
||||
build: ./frontend
|
||||
environment:
|
||||
- VITE_API_HOST=http://localhost:5001
|
||||
- VITE_API_STREAMING=$VITE_API_STREAMING
|
||||
ports:
|
||||
- "5173:5173"
|
||||
depends_on:
|
||||
|
||||
@@ -46,7 +46,7 @@ export function fetchAnswerApi(
|
||||
if (response.ok) {
|
||||
return response.json();
|
||||
} else {
|
||||
Promise.reject(response);
|
||||
return Promise.reject(new Error(response.statusText));
|
||||
}
|
||||
})
|
||||
.then((data) => {
|
||||
@@ -55,6 +55,51 @@ export function fetchAnswerApi(
|
||||
});
|
||||
}
|
||||
|
||||
export function fetchAnswerSteaming(
|
||||
question: string,
|
||||
apiKey: string,
|
||||
selectedDocs: Doc,
|
||||
onEvent: (event: MessageEvent) => void,
|
||||
): Promise<Answer> {
|
||||
let namePath = selectedDocs.name;
|
||||
if (selectedDocs.language === namePath) {
|
||||
namePath = '.project';
|
||||
}
|
||||
|
||||
let docPath = 'default';
|
||||
if (selectedDocs.location === 'local') {
|
||||
docPath = 'local' + '/' + selectedDocs.name + '/';
|
||||
} else if (selectedDocs.location === 'remote') {
|
||||
docPath =
|
||||
selectedDocs.language +
|
||||
'/' +
|
||||
namePath +
|
||||
'/' +
|
||||
selectedDocs.version +
|
||||
'/' +
|
||||
selectedDocs.model +
|
||||
'/';
|
||||
}
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
const url = new URL(apiHost + '/stream');
|
||||
url.searchParams.append('question', question);
|
||||
url.searchParams.append('api_key', apiKey);
|
||||
url.searchParams.append('embeddings_key', apiKey);
|
||||
url.searchParams.append('history', localStorage.getItem('chatHistory'));
|
||||
url.searchParams.append('active_docs', docPath);
|
||||
|
||||
const eventSource = new EventSource(url.href);
|
||||
|
||||
eventSource.onmessage = onEvent;
|
||||
|
||||
eventSource.onerror = (error) => {
|
||||
console.log('Connection failed.');
|
||||
eventSource.close();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
export function sendFeedback(
|
||||
prompt: string,
|
||||
response: string,
|
||||
|
||||
@@ -1,28 +1,64 @@
|
||||
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import store from '../store';
|
||||
import { fetchAnswerApi } from './conversationApi';
|
||||
import { Answer, ConversationState, Query } from './conversationModels';
|
||||
import { fetchAnswerApi, fetchAnswerSteaming } from './conversationApi';
|
||||
import { Answer, ConversationState, Query, Status } from './conversationModels';
|
||||
|
||||
const initialState: ConversationState = {
|
||||
queries: [],
|
||||
status: 'idle',
|
||||
};
|
||||
|
||||
export const fetchAnswer = createAsyncThunk<
|
||||
Answer,
|
||||
{ question: string },
|
||||
{ state: RootState }
|
||||
>('fetchAnswer', async ({ question }, { getState }) => {
|
||||
const state = getState();
|
||||
const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';
|
||||
|
||||
const answer = await fetchAnswerApi(
|
||||
question,
|
||||
state.preference.apiKey,
|
||||
state.preference.selectedDocs!,
|
||||
state.conversation.queries,
|
||||
);
|
||||
return answer;
|
||||
});
|
||||
export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
|
||||
'fetchAnswer',
|
||||
async ({ question }, { dispatch, getState }) => {
|
||||
const state = getState() as RootState;
|
||||
if (state.preference) {
|
||||
if (API_STREAMING) {
|
||||
await fetchAnswerSteaming(
|
||||
question,
|
||||
state.preference.apiKey,
|
||||
state.preference.selectedDocs!,
|
||||
(event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
// check if the 'end' event has been received
|
||||
if (data.type === 'end') {
|
||||
// set status to 'idle'
|
||||
dispatch(conversationSlice.actions.setStatus('idle'));
|
||||
} else {
|
||||
const result = data.answer;
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
index: state.conversation.queries.length - 1,
|
||||
query: { response: result },
|
||||
}),
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
} else {
|
||||
const answer = await fetchAnswerApi(
|
||||
question,
|
||||
state.preference.apiKey,
|
||||
state.preference.selectedDocs!,
|
||||
state.conversation.queries,
|
||||
);
|
||||
if (answer) {
|
||||
dispatch(
|
||||
updateQuery({
|
||||
index: state.conversation.queries.length - 1,
|
||||
query: { response: answer.answer },
|
||||
}),
|
||||
);
|
||||
dispatch(conversationSlice.actions.setStatus('idle'));
|
||||
}
|
||||
}
|
||||
}
|
||||
return { answer: '', query: question, result: '' };
|
||||
},
|
||||
);
|
||||
|
||||
export const conversationSlice = createSlice({
|
||||
name: 'conversation',
|
||||
@@ -31,6 +67,21 @@ export const conversationSlice = createSlice({
|
||||
addQuery(state, action: PayloadAction<Query>) {
|
||||
state.queries.push(action.payload);
|
||||
},
|
||||
updateStreamingQuery(
|
||||
state,
|
||||
action: PayloadAction<{ index: number; query: Partial<Query> }>,
|
||||
) {
|
||||
const index = action.payload.index;
|
||||
if (action.payload.query.response) {
|
||||
state.queries[index].response =
|
||||
(state.queries[index].response || '') + action.payload.query.response;
|
||||
} else {
|
||||
state.queries[index] = {
|
||||
...state.queries[index],
|
||||
...action.payload.query,
|
||||
};
|
||||
}
|
||||
},
|
||||
updateQuery(
|
||||
state,
|
||||
action: PayloadAction<{ index: number; query: Partial<Query> }>,
|
||||
@@ -41,17 +92,15 @@ export const conversationSlice = createSlice({
|
||||
...action.payload.query,
|
||||
};
|
||||
},
|
||||
setStatus(state, action: PayloadAction<Status>) {
|
||||
state.status = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder
|
||||
.addCase(fetchAnswer.pending, (state) => {
|
||||
state.status = 'loading';
|
||||
})
|
||||
.addCase(fetchAnswer.fulfilled, (state, action) => {
|
||||
state.status = 'idle';
|
||||
state.queries[state.queries.length - 1].response =
|
||||
action.payload.answer;
|
||||
})
|
||||
.addCase(fetchAnswer.rejected, (state, action) => {
|
||||
state.status = 'failed';
|
||||
state.queries[state.queries.length - 1].error =
|
||||
@@ -66,5 +115,6 @@ export const selectQueries = (state: RootState) => state.conversation.queries;
|
||||
|
||||
export const selectStatus = (state: RootState) => state.conversation.status;
|
||||
|
||||
export const { addQuery, updateQuery } = conversationSlice.actions;
|
||||
export const { addQuery, updateQuery, updateStreamingQuery } =
|
||||
conversationSlice.actions;
|
||||
export default conversationSlice.reducer;
|
||||
|
||||
Reference in New Issue
Block a user