Small fixes + polishing

This commit is contained in:
Alex
2023-03-21 22:16:09 +00:00
parent ce579293fb
commit 92993ee105
3 changed files with 31 additions and 23 deletions

View File

@@ -1,13 +1,16 @@
import datetime
import json
import os
import traceback
import datetime
import dotenv
import requests
from flask import Flask, request, render_template, redirect, send_from_directory, jsonify
from celery import Celery
from celery.result import AsyncResult
from flask import Flask, request, render_template, send_from_directory, jsonify
from langchain import FAISS
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
from langchain.chains import ChatVectorDBChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
@@ -18,17 +21,12 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from pymongo import MongoClient
from werkzeug.utils import secure_filename
from error import bad_request
from werkzeug.utils import secure_filename
from pymongo import MongoClient
from celery import Celery, current_task
from celery.result import AsyncResult
from worker import ingest_worker
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
if os.getenv("LLM_NAME") is not None:
@@ -98,11 +96,13 @@ mongo = MongoClient(app.config['MONGO_URI'])
db = mongo["docsgpt"]
vectors_collection = db["vectors"]
@celery.task(bind=True)
def ingest(self, directory, formats, name_job, filename, user):
    """Celery background task: hand the ingestion job off to the worker.

    All arguments are forwarded unchanged to ``ingest_worker``; the bound
    task instance (``self``) lets the worker report progress/state.
    Returns whatever the worker returns (serialized as the task result).
    """
    return ingest_worker(self, directory, formats, name_job, filename, user)
@app.route("/")
def home():
return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice,
@@ -163,7 +163,7 @@ def api_answer():
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
template_format="jinja2")
if llm_choice == "openai_chat":
#llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4")
# llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4")
llm = ChatOpenAI(openai_api_key=api_key)
messages_combine = [
SystemMessagePromptTemplate.from_template(chat_combine_template),
@@ -185,12 +185,14 @@ def api_answer():
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
if llm_choice == "openai_chat":
chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
k=4,
chain_type_kwargs={"question_prompt": p_chat_reduce,
"combine_prompt": p_chat_combine}
)
result = chain({"query": question})
chain = ChatVectorDBChain.from_llm(
llm=llm,
vectorstore=docsearch,
prompt=p_chat_combine,
qa_prompt=p_chat_reduce,
top_k_docs_for_context=3,
return_source_documents=False)
result = chain({"question": question, "chat_history": []})
else:
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
combine_prompt=c_prompt, question_prompt=q_prompt)
@@ -200,7 +202,8 @@ def api_answer():
print(result)
# some formatting for the frontend
result['answer'] = result['result']
if "result" in result:
result['answer'] = result['result']
result['answer'] = result['answer'].replace("\\n", "\n")
try:
result['answer'] = result['answer'].split("SOURCES:")[0]
@@ -275,6 +278,7 @@ def api_feedback():
)
return {"status": 'ok'}
@app.route('/api/combine', methods=['GET'])
def combined_json():
user = 'local'
@@ -302,8 +306,9 @@ def combined_json():
index['location'] = "remote"
data.append(index)
return jsonify(data)
@app.route('/api/upload', methods=['POST'])
def upload_file():
"""Upload a file to get vectorized and indexed."""
@@ -321,7 +326,6 @@ def upload_file():
if file.filename == '':
return {"status": 'no file name'}
if file:
filename = secure_filename(file.filename)
# save dir
@@ -338,6 +342,7 @@ def upload_file():
else:
return {"status": 'error'}
@app.route('/api/task_status', methods=['GET'])
def task_status():
"""Get celery job status."""
@@ -346,6 +351,7 @@ def task_status():
task_meta = task.info
return {"status": task.status, "result": task_meta}
### Background task API
@app.route('/api/upload_index', methods=['POST'])
def upload_index_files():
@@ -388,7 +394,6 @@ def upload_index_files():
return {"status": 'ok'}
@app.route('/api/download', methods=['get'])
def download_file():
user = secure_filename(request.args.get('user'))
@@ -397,6 +402,7 @@ def download_file():
save_dir = os.path.join(app.config['UPLOAD_FOLDER'], user, job_name)
return send_from_directory(save_dir, filename, as_attachment=True)
@app.route('/api/delete_old', methods=['get'])
def delete_old():
"""Delete old indexes."""
@@ -417,6 +423,7 @@ def delete_old():
pass
return {"status": 'ok'}
# handling CORS
@app.after_request
def after_request(response):