From 92993ee1059acb21cfdebb8a57d5c81b5e945b3e Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 21 Mar 2023 22:16:09 +0000 Subject: [PATCH] Small fixes + polishing --- application/app.py | 47 +++++++++++++++++++++--------------- application/requirements.txt | 5 ++-- application/worker.py | 2 +- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/application/app.py b/application/app.py index dfce4970..2dafefc0 100644 --- a/application/app.py +++ b/application/app.py @@ -1,13 +1,16 @@ +import datetime import json import os import traceback -import datetime import dotenv import requests -from flask import Flask, request, render_template, redirect, send_from_directory, jsonify +from celery import Celery +from celery.result import AsyncResult +from flask import Flask, request, render_template, send_from_directory, jsonify from langchain import FAISS from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI +from langchain.chains import ChatVectorDBChain from langchain.chains.question_answering import load_qa_chain from langchain.chat_models import ChatOpenAI from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \ @@ -18,17 +21,12 @@ from langchain.prompts.chat import ( SystemMessagePromptTemplate, HumanMessagePromptTemplate, ) +from pymongo import MongoClient +from werkzeug.utils import secure_filename from error import bad_request -from werkzeug.utils import secure_filename -from pymongo import MongoClient - -from celery import Celery, current_task -from celery.result import AsyncResult - from worker import ingest_worker - # os.environ["LANGCHAIN_HANDLER"] = "langchain" if os.getenv("LLM_NAME") is not None: @@ -98,11 +96,13 @@ mongo = MongoClient(app.config['MONGO_URI']) db = mongo["docsgpt"] vectors_collection = db["vectors"] + @celery.task(bind=True) def ingest(self, directory, formats, name_job, filename, user): resp = ingest_worker(self, directory, formats, name_job, filename, user) return resp + @app.route("/") def home(): return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice, @@ -163,7 +163,7 @@ def api_answer(): q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest, template_format="jinja2") if llm_choice == "openai_chat": - #llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4") + # llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4") llm = ChatOpenAI(openai_api_key=api_key) messages_combine = [ SystemMessagePromptTemplate.from_template(chat_combine_template), @@ -185,12 +185,14 @@ def api_answer(): llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key) if llm_choice == "openai_chat": - chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch, - k=4, - chain_type_kwargs={"question_prompt": p_chat_reduce, - "combine_prompt": p_chat_combine} - ) - result = chain({"query": question}) + chain = ChatVectorDBChain.from_llm( + llm=llm, + vectorstore=docsearch, + prompt=p_chat_combine, + qa_prompt=p_chat_reduce, + top_k_docs_for_context=3, + return_source_documents=False) + result = chain({"question": question, "chat_history": []}) else: qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce", combine_prompt=c_prompt, question_prompt=q_prompt) @@ -200,7 +202,8 @@ def api_answer(): print(result) # some formatting for the frontend - result['answer'] = result['result'] + if "result" in result: + result['answer'] = result['result'] result['answer'] = result['answer'].replace("\\n", "\n") try: result['answer'] = result['answer'].split("SOURCES:")[0] @@ -275,6 +278,7 @@ def api_feedback(): ) return {"status": 'ok'} + @app.route('/api/combine', methods=['GET']) def combined_json(): user = 'local' @@ -302,8 +306,9 @@ def combined_json(): index['location'] = "remote" data.append(index) - return jsonify(data) + + @app.route('/api/upload', methods=['POST']) def upload_file(): """Upload a file to get vectorized and indexed.""" @@ -321,7 +326,6 @@ def upload_file(): if file.filename == '': return {"status": 'no file name'} - if file: filename = secure_filename(file.filename) # save dir @@ -338,6 +342,7 @@ def upload_file(): else: return {"status": 'error'} + @app.route('/api/task_status', methods=['GET']) def task_status(): """Get celery job status.""" @@ -346,6 +351,7 @@ def task_status(): task_meta = task.info return {"status": task.status, "result": task_meta} + ### Backgound task api @app.route('/api/upload_index', methods=['POST']) def upload_index_files(): @@ -388,7 +394,6 @@ def upload_index_files(): return {"status": 'ok'} - @app.route('/api/download', methods=['get']) def download_file(): user = secure_filename(request.args.get('user')) @@ -397,6 +402,7 @@ def download_file(): save_dir = os.path.join(app.config['UPLOAD_FOLDER'], user, job_name) return send_from_directory(save_dir, filename, as_attachment=True) + @app.route('/api/delete_old', methods=['get']) def delete_old(): """Delete old indexes.""" @@ -417,6 +423,7 @@ def delete_old(): pass return {"status": 'ok'} + # handling CORS @app.after_request def after_request(response): diff --git a/application/requirements.txt b/application/requirements.txt index f4f3539e..5203b4d6 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -38,7 +38,7 @@ Jinja2==3.1.2 jmespath==1.0.1 joblib==1.2.0 kombu==5.2.4 -langchain==0.0.103 +langchain==0.0.118 lxml==4.9.2 MarkupSafe==2.1.2 marshmallow==3.19.0 @@ -64,12 +64,13 @@ pycryptodomex==3.17 pydantic==1.10.5 PyJWT==2.6.0 pymongo==4.3.3 +PyPDF2==3.0.1 python-dateutil==2.8.2 python-dotenv==1.0.0 python-jose==3.3.0 pytz==2022.7.1 PyYAML==6.0 -redis==4.5.1 +redis==4.5.2 regex==2022.10.31 requests==2.28.2 retry==0.9.2 diff --git a/application/worker.py b/application/worker.py index 268b829a..90bc9bf8 100644 --- a/application/worker.py +++ b/application/worker.py @@ -36,7 +36,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): sample = False token_check = True min_tokens = 150 - max_tokens = 2000 + max_tokens = 1250 full_path = directory + '/' + user + '/' + name_job # check if API_URL env variable is set if not os.environ.get('API_URL'):