mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
@@ -27,6 +27,7 @@ from langchain.prompts.chat import (
|
||||
)
|
||||
from pymongo import MongoClient
|
||||
from werkzeug.utils import secure_filename
|
||||
from langchain.llms import GPT4All
|
||||
|
||||
from core.settings import settings
|
||||
from error import bad_request
|
||||
@@ -196,6 +197,8 @@ def api_answer():
|
||||
llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
|
||||
elif settings.LLM_NAME == "cohere":
|
||||
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
|
||||
elif settings.LLM_NAME == "gpt4all":
|
||||
llm = GPT4All(model=settings.MODEL_PATH)
|
||||
else:
|
||||
raise ValueError("unknown LLM model")
|
||||
|
||||
@@ -211,6 +214,19 @@ def api_answer():
|
||||
# result = chain({"question": question, "chat_history": chat_history})
|
||||
# generate async with async generate method
|
||||
result = run_async_chain(chain, question, chat_history)
|
||||
elif settings.LLM_NAME == "gpt4all":
|
||||
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
|
||||
doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
|
||||
chain = ConversationalRetrievalChain(
|
||||
retriever=docsearch.as_retriever(k=2),
|
||||
question_generator=question_generator,
|
||||
combine_docs_chain=doc_chain,
|
||||
)
|
||||
chat_history = []
|
||||
# result = chain({"question": question, "chat_history": chat_history})
|
||||
# generate async with async generate method
|
||||
result = run_async_chain(chain, question, chat_history)
|
||||
|
||||
else:
|
||||
qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce",
|
||||
combine_prompt=chat_combine_template, question_prompt=q_prompt)
|
||||
|
||||
@@ -9,6 +9,7 @@ class Settings(BaseSettings):
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MODEL_PATH: str = "./models/gpt4all-model.bin"
|
||||
|
||||
API_URL: str = "http://localhost:5001" # backend url for celery worker
|
||||
|
||||
|
||||
@@ -26,10 +26,12 @@ ecdsa==0.18.0
|
||||
entrypoints==0.4
|
||||
faiss-cpu==1.7.3
|
||||
filelock==3.9.0
|
||||
Flask==2.3.2
|
||||
Flask==2.2.3
|
||||
Flask-Cors==3.0.10
|
||||
frozenlist==1.3.3
|
||||
geojson==2.5.0
|
||||
greenlet==2.0.2
|
||||
gpt4all==0.1.7
|
||||
hub==3.0.1
|
||||
huggingface-hub==0.12.1
|
||||
humbug==0.2.8
|
||||
@@ -39,7 +41,8 @@ Jinja2==3.1.2
|
||||
jmespath==1.0.1
|
||||
joblib==1.2.0
|
||||
kombu==5.2.4
|
||||
langchain==0.0.126
|
||||
langchain==0.0.179
|
||||
loguru==0.6.0
|
||||
lxml==4.9.2
|
||||
MarkupSafe==2.1.2
|
||||
marshmallow==3.19.0
|
||||
|
||||
Reference in New Issue
Block a user