mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
switching between llms
This commit is contained in:
@@ -1,56 +1,68 @@
|
||||
import os
|
||||
import pickle
|
||||
import dotenv
|
||||
import datetime
|
||||
from flask import Flask, request, render_template
|
||||
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
import faiss
|
||||
|
||||
import dotenv
|
||||
import requests
|
||||
from flask import Flask, request, render_template
|
||||
from langchain import FAISS
|
||||
from langchain import OpenAI, VectorDBQA, HuggingFaceHub, Cohere
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings, CohereEmbeddings, HuggingFaceInstructEmbeddings
|
||||
from langchain.prompts import PromptTemplate
|
||||
import requests
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
# from manifest import Manifest
|
||||
# from langchain.llms.manifest import ManifestWrapper
|
||||
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
|
||||
# manifest = Manifest(
|
||||
# client_name = "huggingface",
|
||||
# client_connection = "http://127.0.0.1:5000"
|
||||
# )
|
||||
if os.getenv("LLM_NAME") is not None:
|
||||
llm_choice = os.getenv("LLM_NAME")
|
||||
else:
|
||||
llm_choice = "openai"
|
||||
|
||||
if os.getenv("EMBEDDINGS_NAME") is not None:
|
||||
embeddings_choice = os.getenv("EMBEDDINGS_NAME")
|
||||
else:
|
||||
embeddings_choice = "openai_text-embedding-ada-002"
|
||||
|
||||
|
||||
|
||||
if llm_choice == "manifest":
|
||||
from manifest import Manifest
|
||||
from langchain.llms.manifest import ManifestWrapper
|
||||
|
||||
manifest = Manifest(
|
||||
client_name="huggingface",
|
||||
client_connection="http://127.0.0.1:5000"
|
||||
)
|
||||
|
||||
# Redirect PosixPath to WindowsPath on Windows
|
||||
import platform
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import pathlib
|
||||
|
||||
temp = pathlib.PosixPath
|
||||
pathlib.PosixPath = pathlib.WindowsPath
|
||||
|
||||
# loading the .env file
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
with open("combine_prompt.txt", "r") as f:
|
||||
template = f.read()
|
||||
|
||||
# check if OPENAI_API_KEY is set
|
||||
if os.getenv("OPENAI_API_KEY") is not None:
|
||||
if os.getenv("API_KEY") is not None:
|
||||
api_key_set = True
|
||||
|
||||
else:
|
||||
api_key_set = False
|
||||
|
||||
|
||||
if os.getenv("EMBEDDINGS_KEY") is not None:
|
||||
embeddings_key_set = True
|
||||
else:
|
||||
embeddings_key_set = False
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
return render_template("index.html", api_key_set=api_key_set)
|
||||
return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice,
|
||||
embeddings_choice=embeddings_choice)
|
||||
|
||||
|
||||
@app.route("/api/answer", methods=["POST"])
|
||||
@@ -60,7 +72,14 @@ def api_answer():
|
||||
if not api_key_set:
|
||||
api_key = data["api_key"]
|
||||
else:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
api_key = os.getenv("API_KEY")
|
||||
if not embeddings_key_set:
|
||||
embeddings_key = data["embeddings_key"]
|
||||
else:
|
||||
embeddings_key = os.getenv("EMBEDDINGS_KEY")
|
||||
|
||||
print(embeddings_key)
|
||||
print(api_key)
|
||||
|
||||
# check if the vectorstore is set
|
||||
if "active_docs" in data:
|
||||
@@ -70,24 +89,32 @@ def api_answer():
|
||||
else:
|
||||
vectorstore = ""
|
||||
|
||||
|
||||
# loading the index and the store and the prompt template
|
||||
docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=api_key))
|
||||
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
|
||||
if embeddings_choice == "openai_text-embedding-ada-002":
|
||||
docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
|
||||
elif embeddings_choice == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
|
||||
elif embeddings_choice == "huggingface_hkunlp/instructor-large":
|
||||
docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
|
||||
elif embeddings_choice == "cohere_medium":
|
||||
docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
|
||||
|
||||
# create a prompt template
|
||||
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template)
|
||||
# create a chain with the prompt template and the store
|
||||
|
||||
if llm_choice == "openai":
|
||||
llm = OpenAI(openai_api_key=api_key, temperature=0)
|
||||
elif llm_choice == "manifest":
|
||||
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
|
||||
elif llm_choice == "huggingface":
|
||||
llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
|
||||
elif llm_choice == "cohere":
|
||||
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
|
||||
|
||||
#llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
|
||||
llm = OpenAI(openai_api_key=api_key, temperature=0)
|
||||
#llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
|
||||
# llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
|
||||
|
||||
qa_chain = load_qa_chain(llm = llm, chain_type="map_reduce",
|
||||
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
|
||||
combine_prompt=c_prompt)
|
||||
|
||||
|
||||
chain = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch, k=2)
|
||||
|
||||
# fetch the answer
|
||||
@@ -105,6 +132,7 @@ def api_answer():
|
||||
# }
|
||||
return result
|
||||
|
||||
|
||||
@app.route("/api/docs_check", methods=["POST"])
|
||||
def check_docs():
|
||||
# check if docs exist in a vectorstore folder
|
||||
@@ -128,7 +156,8 @@ def check_docs():
|
||||
with open(vectorstore + "faiss_store.pkl", "wb") as f:
|
||||
f.write(r.content)
|
||||
|
||||
return {"status": 'loaded'}
|
||||
return {"status": 'loaded'}
|
||||
|
||||
|
||||
# handling CORS
|
||||
@app.after_request
|
||||
|
||||
Reference in New Issue
Block a user