mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Small fixes + polishing
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
import dotenv
|
||||
import requests
|
||||
from flask import Flask, request, render_template, redirect, send_from_directory, jsonify
|
||||
from celery import Celery
|
||||
from celery.result import AsyncResult
|
||||
from flask import Flask, request, render_template, send_from_directory, jsonify
|
||||
from langchain import FAISS
|
||||
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
|
||||
from langchain.chains import ChatVectorDBChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, \
|
||||
@@ -18,17 +21,12 @@ from langchain.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from pymongo import MongoClient
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from error import bad_request
|
||||
from werkzeug.utils import secure_filename
|
||||
from pymongo import MongoClient
|
||||
|
||||
from celery import Celery, current_task
|
||||
from celery.result import AsyncResult
|
||||
|
||||
from worker import ingest_worker
|
||||
|
||||
|
||||
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
|
||||
if os.getenv("LLM_NAME") is not None:
|
||||
@@ -98,11 +96,13 @@ mongo = MongoClient(app.config['MONGO_URI'])
|
||||
db = mongo["docsgpt"]
|
||||
vectors_collection = db["vectors"]
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest(self, directory, formats, name_job, filename, user):
|
||||
resp = ingest_worker(self, directory, formats, name_job, filename, user)
|
||||
return resp
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice,
|
||||
@@ -163,7 +163,7 @@ def api_answer():
|
||||
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
|
||||
template_format="jinja2")
|
||||
if llm_choice == "openai_chat":
|
||||
#llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4")
|
||||
# llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4")
|
||||
llm = ChatOpenAI(openai_api_key=api_key)
|
||||
messages_combine = [
|
||||
SystemMessagePromptTemplate.from_template(chat_combine_template),
|
||||
@@ -185,12 +185,14 @@ def api_answer():
|
||||
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
|
||||
|
||||
if llm_choice == "openai_chat":
|
||||
chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
|
||||
k=4,
|
||||
chain_type_kwargs={"question_prompt": p_chat_reduce,
|
||||
"combine_prompt": p_chat_combine}
|
||||
)
|
||||
result = chain({"query": question})
|
||||
chain = ChatVectorDBChain.from_llm(
|
||||
llm=llm,
|
||||
vectorstore=docsearch,
|
||||
prompt=p_chat_combine,
|
||||
qa_prompt=p_chat_reduce,
|
||||
top_k_docs_for_context=3,
|
||||
return_source_documents=False)
|
||||
result = chain({"question": question, "chat_history": []})
|
||||
else:
|
||||
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
|
||||
combine_prompt=c_prompt, question_prompt=q_prompt)
|
||||
@@ -200,7 +202,8 @@ def api_answer():
|
||||
print(result)
|
||||
|
||||
# some formatting for the frontend
|
||||
result['answer'] = result['result']
|
||||
if "result" in result:
|
||||
result['answer'] = result['result']
|
||||
result['answer'] = result['answer'].replace("\\n", "\n")
|
||||
try:
|
||||
result['answer'] = result['answer'].split("SOURCES:")[0]
|
||||
@@ -275,6 +278,7 @@ def api_feedback():
|
||||
)
|
||||
return {"status": 'ok'}
|
||||
|
||||
|
||||
@app.route('/api/combine', methods=['GET'])
|
||||
def combined_json():
|
||||
user = 'local'
|
||||
@@ -302,8 +306,9 @@ def combined_json():
|
||||
index['location'] = "remote"
|
||||
data.append(index)
|
||||
|
||||
|
||||
return jsonify(data)
|
||||
|
||||
|
||||
@app.route('/api/upload', methods=['POST'])
|
||||
def upload_file():
|
||||
"""Upload a file to get vectorized and indexed."""
|
||||
@@ -321,7 +326,6 @@ def upload_file():
|
||||
if file.filename == '':
|
||||
return {"status": 'no file name'}
|
||||
|
||||
|
||||
if file:
|
||||
filename = secure_filename(file.filename)
|
||||
# save dir
|
||||
@@ -338,6 +342,7 @@ def upload_file():
|
||||
else:
|
||||
return {"status": 'error'}
|
||||
|
||||
|
||||
@app.route('/api/task_status', methods=['GET'])
|
||||
def task_status():
|
||||
"""Get celery job status."""
|
||||
@@ -346,6 +351,7 @@ def task_status():
|
||||
task_meta = task.info
|
||||
return {"status": task.status, "result": task_meta}
|
||||
|
||||
|
||||
### Backgound task api
|
||||
@app.route('/api/upload_index', methods=['POST'])
|
||||
def upload_index_files():
|
||||
@@ -388,7 +394,6 @@ def upload_index_files():
|
||||
return {"status": 'ok'}
|
||||
|
||||
|
||||
|
||||
@app.route('/api/download', methods=['get'])
|
||||
def download_file():
|
||||
user = secure_filename(request.args.get('user'))
|
||||
@@ -397,6 +402,7 @@ def download_file():
|
||||
save_dir = os.path.join(app.config['UPLOAD_FOLDER'], user, job_name)
|
||||
return send_from_directory(save_dir, filename, as_attachment=True)
|
||||
|
||||
|
||||
@app.route('/api/delete_old', methods=['get'])
|
||||
def delete_old():
|
||||
"""Delete old indexes."""
|
||||
@@ -417,6 +423,7 @@ def delete_old():
|
||||
pass
|
||||
return {"status": 'ok'}
|
||||
|
||||
|
||||
# handling CORS
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
|
||||
Reference in New Issue
Block a user