From f8910ba13637749f92609b9aa1cbdcea53b34ec7 Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 31 May 2023 23:47:16 +0100
Subject: [PATCH] Added history in streaming convo + fixed a little bug with
 message margins on loading state

---
 application/app.py                           | 27 +++++++++++++------
 application/core/settings.py                 |  1 +
 frontend/src/conversation/Conversation.tsx   |  6 +----
 frontend/src/conversation/conversationApi.ts |  3 ++-
 .../src/conversation/conversationSlice.ts    |  1 +
 5 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/application/app.py b/application/app.py
index d0ecb3b9..8441febd 100644
--- a/application/app.py
+++ b/application/app.py
@@ -151,16 +151,25 @@ def home():
 
 def complete_stream(question, docsearch, chat_history, api_key):
     openai.api_key = api_key
+    llm = ChatOpenAI(openai_api_key=api_key)
     docs = docsearch.similarity_search(question, k=2)
     # join all page_content together with a newline
     docs_together = "\n".join([doc.page_content for doc in docs])
-
-    # swap {summaries} in chat_combine_template with the summaries from the docs
     p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
-    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[
-        {"role": "system", "content": p_chat_combine},
-        {"role": "user", "content": question},
-    ], stream=True, max_tokens=1000, temperature=0)
+    messages_combine = [{"role": "system", "content": p_chat_combine}]
+    if len(chat_history) > 1:
+        tokens_current_history = 0
+        # count tokens in history
+        chat_history.reverse()
+        for i in chat_history:
+            if "prompt" in i and "response" in i:
+                tokens_batch = llm.get_num_tokens(i["prompt"]) + llm.get_num_tokens(i["response"])
+                if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
+                    tokens_current_history += tokens_batch
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append({"role": "system", "content": i["response"]})
+    messages_combine.append({"role": "user", "content": question})
+    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages_combine, stream=True, max_tokens=1000, temperature=0)
 
     for line in completion:
         if 'content' in line['choices'][0]['delta']:
@@ -175,6 +184,9 @@ def stream():
     # get parameter from url question
     question = request.args.get('question')
     history = request.args.get('history')
+    # history to json object from string
+    history = json.loads(history)
+
     # check if active_docs is set
 
     if not api_key_set:
@@ -227,13 +239,12 @@ def api_answer():
     messages_combine = [SystemMessagePromptTemplate.from_template(chat_combine_template)]
     if history:
         tokens_current_history = 0
-        tokens_max_history = 1000
         #count tokens in history
         history.reverse()
         for i in history:
             if "prompt" in i and "response" in i:
                 tokens_batch = llm.get_num_tokens(i["prompt"]) + llm.get_num_tokens(i["response"])
-                if tokens_current_history + tokens_batch < tokens_max_history:
+                if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                     tokens_current_history += tokens_batch
                     messages_combine.append(HumanMessagePromptTemplate.from_template(i["prompt"]))
                     messages_combine.append(AIMessagePromptTemplate.from_template(i["response"]))
diff --git a/application/core/settings.py b/application/core/settings.py
index 3c0672da..543c4cf4 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -10,6 +10,7 @@ class Settings(BaseSettings):
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = "./models/gpt4all-model.bin"
+    TOKENS_MAX_HISTORY: int = 150
     API_URL: str = "http://localhost:5001"  # backend url for celery worker
diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx
index fd3f3e93..471a5cd6 100644
--- a/frontend/src/conversation/Conversation.tsx
+++ b/frontend/src/conversation/Conversation.tsx
@@ -79,11 +79,7 @@ export default function Conversation() {
[hunk body not recovered from the source: per the diffstat, 1 insertion and 5 deletions adjusting message margins in the loading state]
diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts
--- a/frontend/src/conversation/conversationApi.ts
+++ b/frontend/src/conversation/conversationApi.ts
[hunk header not recovered from the source]
   question: string,
   apiKey: string,
   selectedDocs: Doc,
+  history: Array<any> = [],
   onEvent: (event: MessageEvent) => void,
 ): Promise<Answer> {
   let namePath = selectedDocs.name;
@@ -86,8 +87,8 @@
   url.searchParams.append('question', question);
   url.searchParams.append('api_key', apiKey);
   url.searchParams.append('embeddings_key', apiKey);
-  url.searchParams.append('history', localStorage.getItem('chatHistory'));
   url.searchParams.append('active_docs', docPath);
+  url.searchParams.append('history', JSON.stringify(history));
 
   const eventSource = new EventSource(url.href);
diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts
index 37a7b0e7..70fa1a81 100644
--- a/frontend/src/conversation/conversationSlice.ts
+++ b/frontend/src/conversation/conversationSlice.ts
@@ -20,6 +20,7 @@ export const fetchAnswer = createAsyncThunk(
       question,
       state.preference.apiKey,
       state.preference.selectedDocs!,
+      state.conversation.queries,
       (event) => {
         const data = JSON.parse(event.data);