mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
@@ -1,19 +1,25 @@
|
||||
FROM python:3.10-slim-bullseye as builder
|
||||
FROM python:3.11-slim-bullseye as builder
|
||||
|
||||
# Tiktoken requires Rust toolchain, so build it in a separate stage
|
||||
RUN apt-get update && apt-get install -y gcc curl
|
||||
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && apt-get install --reinstall libc6-dev -y
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
RUN pip install --upgrade pip && pip install tiktoken==0.3.3
|
||||
RUN pip install --upgrade pip && pip install tiktoken==0.5.2
|
||||
COPY requirements.txt .
|
||||
RUN pip install -r requirements.txt
|
||||
RUN apt-get install -y wget unzip
|
||||
RUN wget https://d3dg1063dc54p9.cloudfront.net/models/embeddings/mpnet-base-v2.zip
|
||||
RUN unzip mpnet-base-v2.zip -d model
|
||||
RUN rm mpnet-base-v2.zip
|
||||
|
||||
FROM python:3.10-slim-bullseye
|
||||
FROM python:3.11-slim-bullseye
|
||||
|
||||
# Copy pre-built packages and binaries from builder stage
|
||||
COPY --from=builder /usr/local/ /usr/local/
|
||||
|
||||
WORKDIR /app
|
||||
COPY --from=builder /model /app/model
|
||||
|
||||
COPY . /app/application
|
||||
ENV FLASK_APP=app.py
|
||||
ENV FLASK_DEBUG=true
|
||||
|
||||
@@ -25,30 +25,30 @@ mongo = MongoClient(settings.MONGO_URI)
|
||||
db = mongo["docsgpt"]
|
||||
conversations_collection = db["conversations"]
|
||||
vectors_collection = db["vectors"]
|
||||
prompts_collection = db["prompts"]
|
||||
answer = Blueprint('answer', __name__)
|
||||
|
||||
if settings.LLM_NAME == "gpt4":
|
||||
gpt_model = 'gpt-4'
|
||||
elif settings.LLM_NAME == "anthropic":
|
||||
gpt_model = 'claude-2'
|
||||
else:
|
||||
gpt_model = 'gpt-3.5-turbo'
|
||||
|
||||
# load the prompts
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
with open(os.path.join(current_dir, "prompts", "combine_prompt.txt"), "r") as f:
|
||||
template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "combine_prompt_hist.txt"), "r") as f:
|
||||
template_hist = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "question_prompt.txt"), "r") as f:
|
||||
template_quest = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_prompt.txt"), "r") as f:
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
|
||||
chat_reduce_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
||||
chat_combine_creative = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
||||
chat_combine_strict = f.read()
|
||||
|
||||
api_key_set = settings.API_KEY is not None
|
||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
||||
|
||||
@@ -77,11 +77,10 @@ def run_async_chain(chain, question, chat_history):
|
||||
|
||||
def get_vectorstore(data):
|
||||
if "active_docs" in data:
|
||||
if data["active_docs"].split("/")[0] == "local":
|
||||
if data["active_docs"].split("/")[1] == "default":
|
||||
if data["active_docs"].split("/")[0] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = "indexes/" + data["active_docs"]
|
||||
elif data["active_docs"].split("/")[0] == "local":
|
||||
vectorstore = "indexes/" + data["active_docs"]
|
||||
else:
|
||||
vectorstore = "vectors/" + data["active_docs"]
|
||||
if data["active_docs"] == "default":
|
||||
@@ -92,47 +91,35 @@ def get_vectorstore(data):
|
||||
return vectorstore
|
||||
|
||||
|
||||
# def get_docsearch(vectorstore, embeddings_key):
|
||||
# if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
|
||||
# if is_azure_configured():
|
||||
# os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
# openai_embeddings = OpenAIEmbeddings(model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME)
|
||||
# else:
|
||||
# openai_embeddings = OpenAIEmbeddings(openai_api_key=embeddings_key)
|
||||
# docsearch = FAISS.load_local(vectorstore, openai_embeddings)
|
||||
# elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
# docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
|
||||
# elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
|
||||
# docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
|
||||
# elif settings.EMBEDDINGS_NAME == "cohere_medium":
|
||||
# docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
|
||||
# return docsearch
|
||||
|
||||
|
||||
def is_azure_configured():
    """Report whether all Azure OpenAI settings needed for answering are set.

    Returns the result of and-ing the three settings together, so the value is
    the first falsy setting or, when all are truthy, the deployment name. It is
    meant to be used in a boolean context.
    """
    return (
        settings.OPENAI_API_BASE
        and settings.OPENAI_API_VERSION
        and settings.AZURE_DEPLOYMENT_NAME
    )
|
||||
|
||||
|
||||
def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
|
||||
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
|
||||
|
||||
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
|
||||
docs = docsearch.search(question, k=2)
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
|
||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
data = json.dumps({"type": "source", "doc": doc.page_content, "metadata": doc.metadata})
|
||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
||||
else:
|
||||
data = json.dumps({"type": "source", "doc": doc.page_content})
|
||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
||||
yield f"data:{data}\n\n"
|
||||
|
||||
if len(chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
@@ -199,6 +186,10 @@ def stream():
|
||||
# history to json object from string
|
||||
history = json.loads(history)
|
||||
conversation_id = data["conversation_id"]
|
||||
if 'prompt_id' in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = 'default'
|
||||
|
||||
# check if active_docs is set
|
||||
|
||||
@@ -219,6 +210,7 @@ def stream():
|
||||
return Response(
|
||||
complete_stream(question, docsearch,
|
||||
chat_history=history, api_key=api_key,
|
||||
prompt_id=prompt_id,
|
||||
conversation_id=conversation_id), mimetype="text/event-stream"
|
||||
)
|
||||
|
||||
@@ -241,6 +233,19 @@ def api_answer():
|
||||
embeddings_key = data["embeddings_key"]
|
||||
else:
|
||||
embeddings_key = settings.EMBEDDINGS_KEY
|
||||
if 'prompt_id' in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = 'default'
|
||||
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
@@ -258,7 +263,7 @@ def api_answer():
|
||||
docs = docsearch.search(question, k=2)
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
|
||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
@@ -335,3 +340,35 @@ def api_answer():
|
||||
traceback.print_exc()
|
||||
print(str(e))
|
||||
return bad_request(500, str(e))
|
||||
|
||||
|
||||
@answer.route("/api/search", methods=["POST"])
def api_search():
    """Search the vector store for the posted question and return the matching sources.

    Expects a JSON body with a required ``question`` and optional
    ``embeddings_key`` and ``active_docs`` entries. No LLM is invoked; the
    response is a list of ``{"title": ..., "text": ...}`` dicts, one per
    retrieved document.
    """
    payload = request.get_json()
    # the question is taken from the JSON body
    question = payload["question"]

    # A key supplied in the request is only used when no embeddings key is
    # configured in settings; otherwise the configured value wins.
    if not embeddings_key_set and "embeddings_key" in payload:
        embeddings_key = payload["embeddings_key"]
    else:
        embeddings_key = settings.EMBEDDINGS_KEY

    if "active_docs" in payload:
        vectorstore = get_vectorstore({"active_docs": payload["active_docs"]})
    else:
        vectorstore = ""
    docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

    docs = docsearch.search(question, k=2)

    def to_source(doc):
        # Documents with metadata are titled by the last path segment of their
        # "title" entry; documents without metadata use their own text as title.
        if doc.metadata:
            return {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}
        return {"title": doc.page_content, "text": doc.page_content}

    return [to_source(doc) for doc in docs]
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import os
|
||||
from flask import Blueprint, request, jsonify
|
||||
import requests
|
||||
import json
|
||||
from pymongo import MongoClient
|
||||
from bson.objectid import ObjectId
|
||||
from werkzeug.utils import secure_filename
|
||||
import http.client
|
||||
|
||||
from application.api.user.tasks import ingest
|
||||
|
||||
@@ -16,6 +14,8 @@ mongo = MongoClient(settings.MONGO_URI)
|
||||
db = mongo["docsgpt"]
|
||||
conversations_collection = db["conversations"]
|
||||
vectors_collection = db["vectors"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
user = Blueprint('user', __name__)
|
||||
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -70,20 +70,29 @@ def api_feedback():
|
||||
answer = data["answer"]
|
||||
feedback = data["feedback"]
|
||||
|
||||
print("-" * 5)
|
||||
print("Question: " + question)
|
||||
print("Answer: " + answer)
|
||||
print("Feedback: " + feedback)
|
||||
print("-" * 5)
|
||||
response = requests.post(
|
||||
url="https://86x89umx77.execute-api.eu-west-2.amazonaws.com/docsgpt-feedback",
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
},
|
||||
data=json.dumps({"answer": answer, "question": question, "feedback": feedback}),
|
||||
)
|
||||
return {"status": http.client.responses.get(response.status_code, "ok")}
|
||||
|
||||
feedback_collection.insert_one(
|
||||
{
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"feedback": feedback,
|
||||
}
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
@user.route("/api/delete_by_ids", methods=["get"])
def delete_by_ids():
    """Delete by ID. These are the IDs in the vectorstore"""
    ids = request.args.get("path")
    # A missing or empty "path" query argument is reported as an error.
    if not ids:
        return {"status": "error"}

    # Only the faiss vector store is handled; every other store is an error.
    if settings.VECTOR_STORE != "faiss":
        return {"status": "error"}

    # NOTE(review): vectors_collection is created as db["vectors"] in this
    # file; confirm that object actually provides a delete_index method.
    deleted = vectors_collection.delete_index(ids=ids)
    return {"status": "ok"} if deleted else {"status": "error"}
|
||||
|
||||
@user.route("/api/delete_old", methods=["get"])
|
||||
def delete_old():
|
||||
@@ -140,7 +149,9 @@ def upload_file():
|
||||
os.makedirs(save_dir)
|
||||
|
||||
file.save(os.path.join(save_dir, filename))
|
||||
task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt"], job_name, filename, user)
|
||||
task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx",
|
||||
".csv", ".epub", ".html", ".mdx"],
|
||||
job_name, filename, user)
|
||||
# task id
|
||||
task_id = task.id
|
||||
return {"status": "ok", "task_id": task_id}
|
||||
@@ -173,7 +184,7 @@ def combined_json():
|
||||
"date": "default",
|
||||
"docLink": "default",
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"location": "remote",
|
||||
}
|
||||
]
|
||||
# structure: name, language, version, description, fullName, date, docLink
|
||||
@@ -230,6 +241,80 @@ def check_docs():
|
||||
|
||||
return {"status": "loaded"}
|
||||
|
||||
@user.route("/api/create_prompt", methods=["POST"])
def create_prompt():
    """Store a new prompt and return its id.

    Expects a JSON body with ``content`` and ``name``. An empty name yields
    ``{"status": "error"}``; otherwise the prompt is inserted for the fixed
    user "local" and ``{"id": <inserted id as string>}`` is returned.
    """
    payload = request.get_json()
    content = payload["content"]
    name = payload["name"]
    if name == "":
        return {"status": "error"}

    owner = "local"
    document = {
        "name": name,
        "content": content,
        "user": owner,
    }
    inserted = prompts_collection.insert_one(document)
    return {"id": str(inserted.inserted_id)}
|
||||
|
||||
@user.route("/api/get_prompts", methods=["GET"])
def get_prompts():
    """List the available prompts as JSON.

    The three built-in prompts (default, creative, strict) come first and are
    marked "public"; prompts stored for the fixed user "local" follow and are
    marked "private".
    """
    owner = "local"
    stored = prompts_collection.find({"user": owner})

    builtin = [
        {"id": builtin_name, "name": builtin_name, "type": "public"}
        for builtin_name in ("default", "creative", "strict")
    ]
    private = [
        {"id": str(entry["_id"]), "name": entry["name"], "type": "private"}
        for entry in stored
    ]
    return jsonify(builtin + private)
|
||||
|
||||
@user.route("/api/get_single_prompt", methods=["GET"])
def get_single_prompt():
    """Return the content of one prompt as ``{"content": ...}`` JSON.

    The ``id`` query argument selects the prompt. The ids "default",
    "creative" and "strict" are read from files in the prompts directory;
    any other id is looked up in the prompts collection by ObjectId.
    """
    prompt_id = request.args.get("id")

    # Built-in prompt ids and the file each one is stored in.
    builtin_files = {
        'default': "chat_combine_default.txt",
        'creative': "chat_combine_creative.txt",
        'strict': "chat_combine_strict.txt",
    }
    filename = builtin_files.get(prompt_id)
    if filename is not None:
        with open(os.path.join(current_dir, "prompts", filename), "r") as f:
            builtin_content = f.read()
        return jsonify({"content": builtin_content})

    stored = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
    return jsonify({"content": stored["content"]})
|
||||
|
||||
@user.route("/api/delete_prompt", methods=["POST"])
def delete_prompt():
    """Delete the stored prompt whose id is given in the JSON body.

    Expects ``{"id": <ObjectId string>}`` and always answers
    ``{"status": "ok"}`` once the delete call has been issued.
    """
    payload = request.get_json()
    prompt_id = payload["id"]
    prompts_collection.delete_one({"_id": ObjectId(prompt_id)})
    return {"status": "ok"}
|
||||
|
||||
@user.route("/api/update_prompt", methods=["POST"])
def update_prompt_name():
    """Update the name and content of a stored prompt.

    Expects a JSON body with ``id``, ``name`` and ``content``. An empty name
    yields ``{"status": "error"}`` and nothing is written; otherwise both
    fields are set on the matching document and ``{"status": "ok"}`` is
    returned.
    """
    payload = request.get_json()
    prompt_id = payload["id"]
    name = payload["name"]
    content = payload["content"]

    # reject an empty name before touching the database
    if name == "":
        return {"status": "error"}

    selector = {"_id": ObjectId(prompt_id)}
    changes = {"$set": {"name": name, "content": content}}
    prompts_collection.update_one(selector, changes)
    return {"status": "ok"}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
LLM_NAME: str = "openai"
|
||||
EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
|
||||
LLM_NAME: str = "docsgpt"
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
@@ -18,25 +19,25 @@ class Settings(BaseSettings):
|
||||
|
||||
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
||||
|
||||
API_KEY: str = None # LLM api key
|
||||
EMBEDDINGS_KEY: str = None # api key for embeddings (if using openai, just copy API_KEY
|
||||
OPENAI_API_BASE: str = None # azure openai api base url
|
||||
OPENAI_API_VERSION: str = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: str = None # azure deployment name for answering
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: str = None # azure deployment name for embeddings
|
||||
API_KEY: Optional[str] = None # LLM api key
|
||||
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
|
||||
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
|
||||
|
||||
# elasticsearch
|
||||
ELASTIC_CLOUD_ID: str = None # cloud id for elasticsearch
|
||||
ELASTIC_USERNAME: str = None # username for elasticsearch
|
||||
ELASTIC_PASSWORD: str = None # password for elasticsearch
|
||||
ELASTIC_URL: str = None # url for elasticsearch
|
||||
ELASTIC_INDEX: str = "docsgpt" # index name for elasticsearch
|
||||
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
|
||||
ELASTIC_USERNAME: Optional[str] = None # username for elasticsearch
|
||||
ELASTIC_PASSWORD: Optional[str] = None # password for elasticsearch
|
||||
ELASTIC_URL: Optional[str] = None # url for elasticsearch
|
||||
ELASTIC_INDEX: Optional[str] = "docsgpt" # index name for elasticsearch
|
||||
|
||||
# SageMaker config
|
||||
SAGEMAKER_ENDPOINT: str = None # SageMaker endpoint name
|
||||
SAGEMAKER_REGION: str = None # SageMaker region name
|
||||
SAGEMAKER_ACCESS_KEY: str = None # SageMaker access key
|
||||
SAGEMAKER_SECRET_KEY: str = None # SageMaker secret key
|
||||
SAGEMAKER_ENDPOINT: Optional[str] = None # SageMaker endpoint name
|
||||
SAGEMAKER_REGION: Optional[str] = None # SageMaker region name
|
||||
SAGEMAKER_ACCESS_KEY: Optional[str] = None # SageMaker access key
|
||||
SAGEMAKER_SECRET_KEY: Optional[str] = None # SageMaker secret key
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
40
application/llm/anthropic.py
Normal file
40
application/llm/anthropic.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
|
||||
class AnthropicLLM(BaseLLM):
    """LLM backend that sends prompts to Anthropic's completions API."""

    def __init__(self, api_key=None):
        """Create the Anthropic client.

        Args:
            api_key: Key for the Anthropic client. When falsy,
                ``settings.ANTHROPIC_API_KEY`` is used instead.
        """
        # Imported inside __init__ rather than at module import time.
        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
        self.api_key = api_key or settings.ANTHROPIC_API_KEY  # If not provided, use a default from settings
        self.anthropic = Anthropic(api_key=self.api_key)
        self.HUMAN_PROMPT = HUMAN_PROMPT
        self.AI_PROMPT = AI_PROMPT

    def _build_prompt(self, messages):
        """Build the completion prompt from the first and last chat messages.

        The first message's content is used as context and the last message's
        content as the question; messages in between are not used.
        """
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
        return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}"

    def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
        """Generate a completion for ``messages``.

        Args:
            model: Model name passed to the completions API.
            messages: List of ``{"content": ...}`` chat messages.
            engine: Accepted for interface compatibility; not used here.
            max_tokens: Upper bound passed as ``max_tokens_to_sample``.
            stream: When true, delegate to :meth:`gen_stream`.
            **kwargs: Forwarded to :meth:`gen_stream` when streaming;
                otherwise unused.

        Returns:
            The completion text, or the generator from :meth:`gen_stream`
            when ``stream`` is true.
        """
        if stream:
            # Forward the original messages by keyword. Passing the formatted
            # prompt string positionally made gen_stream index into a string
            # and shifted max_tokens into the engine slot.
            return self.gen_stream(model, messages, engine=engine, max_tokens=max_tokens, **kwargs)

        completion = self.anthropic.completions.create(
            model=model,
            max_tokens_to_sample=max_tokens,
            stream=False,
            prompt=self._build_prompt(messages),
        )
        return completion.completion

    def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
        """Yield completion text pieces for ``messages`` as they arrive.

        Args:
            model: Model name passed to the completions API.
            messages: List of ``{"content": ...}`` chat messages.
            engine: Accepted for interface compatibility; not used here.
            max_tokens: Upper bound passed as ``max_tokens_to_sample``.
            **kwargs: Accepted and ignored.

        Yields:
            The ``completion`` attribute of each streamed response item.
        """
        stream_response = self.anthropic.completions.create(
            model=model,
            prompt=self._build_prompt(messages),
            max_tokens_to_sample=max_tokens,
            stream=True,
        )

        for completion in stream_response:
            yield completion.completion
|
||||
49
application/llm/docsgpt_provider.py
Normal file
49
application/llm/docsgpt_provider.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from application.llm.base import BaseLLM
|
||||
import json
|
||||
import requests
|
||||
|
||||
class DocsGPTAPILLM(BaseLLM):
    """LLM backend that posts prompts to a remote DocsGPT HTTP endpoint."""

    def __init__(self, *args, **kwargs):
        """Set the endpoint URL; any arguments are accepted and ignored."""
        self.endpoint = "https://llm.docsgpt.co.uk"

    @staticmethod
    def _build_prompt(messages):
        """Build the instruction prompt from the first and last chat messages.

        The first message's content is used as context and the last message's
        content as the question; messages in between are not used.
        """
        context = messages[0]['content']
        user_question = messages[-1]['content']
        return f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

    def gen(self, model, engine, messages, stream=False, **kwargs):
        """Request a complete answer from the ``/answer`` endpoint.

        ``model``, ``engine``, ``stream`` and ``**kwargs`` are accepted for
        interface compatibility and are not used.

        Returns:
            The ``"a"`` field of the JSON response, cut at the first "###".
        """
        prompt = self._build_prompt(messages)

        # NOTE(review): max_new_tokens is 30 here but 256 in gen_stream;
        # confirm the small limit is intentional.
        response = requests.post(
            f"{self.endpoint}/answer",
            json={
                "prompt": prompt,
                "max_new_tokens": 30
            }
        )
        response_clean = response.json()['a'].split("###")[0]

        return response_clean

    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
        """Stream an answer from the ``/stream`` endpoint.

        ``model``, ``engine``, ``stream`` and ``**kwargs`` are accepted for
        interface compatibility and are not used.

        Yields:
            The ``"a"`` field of every ``data: <json>`` line in the response.
        """
        prompt = self._build_prompt(messages)

        # The context manager releases the streamed connection even when the
        # consumer stops iterating early or a line fails to parse.
        with requests.post(
            f"{self.endpoint}/stream",
            json={
                "prompt": prompt,
                "max_new_tokens": 256
            },
            stream=True
        ) as response:
            for line in response.iter_lines():
                if not line:
                    continue
                data_str = line.decode('utf-8')
                # Only lines carrying the "data: " prefix hold a JSON payload.
                if data_str.startswith("data: "):
                    data = json.loads(data_str[6:])
                    yield data['a']
|
||||
|
||||
@@ -2,6 +2,8 @@ from application.llm.openai import OpenAILLM, AzureOpenAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.huggingface import HuggingFaceLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +13,9 @@ class LLMCreator:
|
||||
'azure_openai': AzureOpenAILLM,
|
||||
'sagemaker': SagemakerAPILLM,
|
||||
'huggingface': HuggingFaceLLM,
|
||||
'llama.cpp': LlamaCpp
|
||||
'llama.cpp': LlamaCpp,
|
||||
'anthropic': AnthropicLLM,
|
||||
'docsgpt': DocsGPTAPILLM
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -5,40 +5,38 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key):
|
||||
global openai
|
||||
import openai
|
||||
openai.api_key = api_key
|
||||
self.api_key = api_key # Save the API key to be used later
|
||||
from openai import OpenAI
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
)
|
||||
self.api_key = api_key
|
||||
|
||||
def _get_openai(self):
|
||||
# Import openai when needed
|
||||
import openai
|
||||
# Set the API key every time you import openai
|
||||
openai.api_key = self.api_key
|
||||
|
||||
return openai
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
engine=engine,
|
||||
response = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
**kwargs
|
||||
)
|
||||
**kwargs)
|
||||
|
||||
return response["choices"][0]["message"]["content"]
|
||||
return response.choices[0].message.content
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
engine=engine,
|
||||
response = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
**kwargs
|
||||
)
|
||||
**kwargs)
|
||||
|
||||
for line in response:
|
||||
if "content" in line["choices"][0]["delta"]:
|
||||
yield line["choices"][0]["delta"]["content"]
|
||||
# import sys
|
||||
# print(line.choices[0].delta.content, file=sys.stderr)
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
@@ -48,10 +46,15 @@ class AzureOpenAILLM(OpenAILLM):
|
||||
self.api_base = settings.OPENAI_API_BASE,
|
||||
self.api_version = settings.OPENAI_API_VERSION,
|
||||
self.deployment_name = settings.AZURE_DEPLOYMENT_NAME,
|
||||
from openai import AzureOpenAI
|
||||
self.client = AzureOpenAI(
|
||||
api_key=openai_api_key,
|
||||
api_version=settings.OPENAI_API_VERSION,
|
||||
api_base=settings.OPENAI_API_BASE,
|
||||
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
|
||||
)
|
||||
|
||||
def _get_openai(self):
|
||||
openai = super()._get_openai()
|
||||
openai.api_base = self.api_base
|
||||
openai.api_version = self.api_version
|
||||
openai.api_type = "azure"
|
||||
|
||||
return openai
|
||||
|
||||
@@ -62,7 +62,6 @@ class SimpleDirectoryReader(BaseReader):
|
||||
file_extractor: Optional[Dict[str, BaseParser]] = None,
|
||||
num_files_limit: Optional[int] = None,
|
||||
file_metadata: Optional[Callable[[str], Dict]] = None,
|
||||
chunk_size_max: int = 2048,
|
||||
) -> None:
|
||||
"""Initialize with parameters."""
|
||||
super().__init__()
|
||||
|
||||
51
application/parser/file/openapi3_parser.py
Normal file
51
application/parser/file/openapi3_parser.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openapi_parser import parse
|
||||
|
||||
try:
|
||||
from application.parser.file.base_parser import BaseParser
|
||||
except ModuleNotFoundError:
|
||||
from base_parser import BaseParser
|
||||
|
||||
|
||||
class OpenAPI3Parser(BaseParser):
    """Parser that turns an OpenAPI 3 specification into a plain-text summary."""

    def init_parser(self) -> None:
        """Delegate parser initialisation to the base class."""
        return super().init_parser()

    def get_base_urls(self, urls):
        """Return the distinct ``scheme://netloc`` prefixes of ``urls``.

        Order of first appearance is preserved.
        """
        base_urls = []
        for url in urls:
            parsed_url = urlparse(url)
            base_url = parsed_url.scheme + "://" + parsed_url.netloc
            if base_url not in base_urls:
                base_urls.append(base_url)
        return base_urls

    def get_info_from_paths(self, path):
        """Describe each operation of ``path`` as ``"\\n<method>=<description>"``.

        The description is taken from the operation's first response. An
        operation without responses gets an empty description instead of
        raising IndexError. Returns an empty string when the path has no
        operations.
        """
        info = ""
        if path.operations:
            for operation in path.operations:
                responses = operation.responses
                description = responses[0].description if responses else ""
                info += f"\n{operation.method.value}={description}"
        return info

    def parse_file(self, file_path):
        """Parse the specification at ``file_path`` and return its text summary.

        The summary lists the server base URLs followed by one numbered entry
        per path with its URL, description, parameters and methods.
        """
        data = parse(file_path)
        base_urls = ",".join(self.get_base_urls(link.url for link in data.servers))
        results = f"Base URL:{base_urls}\n"
        for index, path in enumerate(data.paths, start=1):
            info = self.get_info_from_paths(path)
            results += (
                f"Path{index}: {path.url}\n"
                f"description: {path.description}\n"
                f"parameters: {path.parameters}\nmethods: {info}\n"
            )
        # NOTE(review): every call also writes the summary to "results.txt"
        # in the current working directory; confirm this is still wanted.
        with open("results.txt", "w") as f:
            f.write(results)
        return results
|
||||
@@ -6,9 +6,9 @@ from application.core.settings import settings
|
||||
from retry import retry
|
||||
|
||||
|
||||
# from langchain.embeddings import HuggingFaceEmbeddings
|
||||
# from langchain.embeddings import HuggingFaceInstructEmbeddings
|
||||
# from langchain.embeddings import CohereEmbeddings
|
||||
# from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
# from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
||||
# from langchain_community.embeddings import CohereEmbeddings
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
||||
|
||||
9
application/prompts/chat_combine_default.txt
Normal file
9
application/prompts/chat_combine_default.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
You are a helpful AI assistant, DocsGPT, specializing in document assistance, designed to offer detailed and informative responses.
|
||||
If appropriate, your answers can include code examples, formatted as follows:
|
||||
```(language)
|
||||
(code)
|
||||
```
|
||||
You effectively utilize chat history, ensuring relevant and tailored responses.
|
||||
If a question doesn't align with your context, you provide friendly and helpful replies.
|
||||
----------------
|
||||
{summaries}
|
||||
13
application/prompts/chat_combine_strict.txt
Normal file
13
application/prompts/chat_combine_strict.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
You are an AI Assistant, DocsGPT, adept at offering document assistance.
|
||||
Your expertise lies in providing answers based on the provided context.
|
||||
You can leverage the chat history if needed.
|
||||
Answer the question based on the context below.
|
||||
Keep the answer concise. Respond "Irrelevant context" if not sure about the answer.
|
||||
If the question is not related to the context, respond "Irrelevant context".
|
||||
When using code examples, use the following format:
|
||||
```(language)
|
||||
(code)
|
||||
```
|
||||
----------------
|
||||
Context:
|
||||
{summaries}
|
||||
@@ -1,25 +0,0 @@
|
||||
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
|
||||
|
||||
QUESTION: How to merge tables in pandas?
|
||||
=========
|
||||
Content: pandas provides various facilities for easily combining together Series or DataFrame with various kinds of set logic for the indexes and relational algebra functionality in the case of join / merge-type operations.
|
||||
Source: 28-pl
|
||||
Content: pandas provides a single function, merge(), as the entry point for all standard database join operations between DataFrame or named Series objects: \n\npandas.merge(left, right, how='inner', on=None, left_on=None, right_on=None, left_index=False, right_index=False, sort=False, suffixes=('_x', '_y'), copy=True, indicator=False, validate=None)
|
||||
Source: 30-pl
|
||||
=========
|
||||
FINAL ANSWER: To merge two tables in pandas, you can use the pd.merge() function. The basic syntax is: \n\npd.merge(left, right, on, how) \n\nwhere left and right are the two tables to merge, on is the column to merge on, and how is the type of merge to perform. \n\nFor example, to merge the two tables df1 and df2 on the column 'id', you can use: \n\npd.merge(df1, df2, on='id', how='inner')
|
||||
SOURCES: 28-pl 30-pl
|
||||
|
||||
QUESTION: How are you?
|
||||
=========
|
||||
CONTENT:
|
||||
SOURCE:
|
||||
=========
|
||||
FINAL ANSWER: I am fine, thank you. How are you?
|
||||
SOURCES:
|
||||
|
||||
QUESTION: {{ question }}
|
||||
=========
|
||||
{{ summaries }}
|
||||
=========
|
||||
FINAL ANSWER:
|
||||
@@ -1,33 +0,0 @@
|
||||
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.
|
||||
|
||||
QUESTION: How to merge tables in pandas?
|
||||
=========
|
||||
Content: pandas provides various facilities for easily combining together Series or DataFrame with various kinds of set logic for the indexes and relational algebra functionality in the case of join / merge-type operations.
|
||||
Source: 28-pl
|
||||
Content: pandas provides a single function, merge(), as the entry point for all standard database join operations between DataFrame or named Series objects: \n\npandas.merge(left, right, how='inner', on=None, left_on=None, right_on=None, left_index=False, right_index=False, sort=False, suffixes=('_x', '_y'), copy=True, indicator=False, validate=None)
|
||||
Source: 30-pl
|
||||
=========
|
||||
FINAL ANSWER: To merge two tables in pandas, you can use the pd.merge() function. The basic syntax is: \n\npd.merge(left, right, on, how) \n\nwhere left and right are the two tables to merge, on is the column to merge on, and how is the type of merge to perform. \n\nFor example, to merge the two tables df1 and df2 on the column 'id', you can use: \n\npd.merge(df1, df2, on='id', how='inner')
|
||||
SOURCES: 28-pl 30-pl
|
||||
|
||||
QUESTION: How are you?
|
||||
=========
|
||||
CONTENT:
|
||||
SOURCE:
|
||||
=========
|
||||
FINAL ANSWER: I am fine, thank you. How are you?
|
||||
SOURCES:
|
||||
|
||||
QUESTION: {{ historyquestion }}
|
||||
=========
|
||||
CONTENT:
|
||||
SOURCE:
|
||||
=========
|
||||
FINAL ANSWER: {{ historyanswer }}
|
||||
SOURCES:
|
||||
|
||||
QUESTION: {{ question }}
|
||||
=========
|
||||
{{ summaries }}
|
||||
=========
|
||||
FINAL ANSWER:
|
||||
@@ -1,4 +0,0 @@
|
||||
Use the following portion of a long document to see if any of the text is relevant to answer the question.
|
||||
{{ context }}
|
||||
Question: {{ question }}
|
||||
Provide all relevant text to the question verbatim. Summarize if needed. If nothing relevant return "-".
|
||||
@@ -1,106 +1,33 @@
|
||||
aiodns==3.0.0
|
||||
aiohttp==3.8.5
|
||||
aiohttp-retry==2.8.3
|
||||
aiosignal==1.3.1
|
||||
aleph-alpha-client==2.16.1
|
||||
amqp==5.1.1
|
||||
async-timeout==4.0.2
|
||||
attrs==22.2.0
|
||||
billiard==3.6.4.0
|
||||
blobfile==2.0.1
|
||||
boto3==1.28.20
|
||||
celery==5.2.7
|
||||
cffi==1.15.1
|
||||
charset-normalizer==3.1.0
|
||||
click==8.1.3
|
||||
click-didyoumean==0.3.0
|
||||
click-plugins==1.1.1
|
||||
click-repl==0.2.0
|
||||
cryptography==41.0.4
|
||||
dataclasses-json==0.5.7
|
||||
decorator==5.1.1
|
||||
dill==0.3.6
|
||||
dnspython==2.3.0
|
||||
ecdsa==0.18.0
|
||||
elasticsearch==8.9.0
|
||||
entrypoints==0.4
|
||||
faiss-cpu==1.7.3
|
||||
filelock==3.9.0
|
||||
Flask==2.2.5
|
||||
Flask-Cors==3.0.10
|
||||
frozenlist==1.3.3
|
||||
geojson==2.5.0
|
||||
gunicorn==20.1.0
|
||||
greenlet==2.0.2
|
||||
gpt4all==0.1.7
|
||||
huggingface-hub==0.15.1
|
||||
humbug==0.3.2
|
||||
idna==3.4
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
jmespath==1.0.1
|
||||
joblib==1.2.0
|
||||
kombu==5.2.4
|
||||
langchain==0.0.308
|
||||
loguru==0.6.0
|
||||
lxml==4.9.2
|
||||
MarkupSafe==2.1.2
|
||||
marshmallow==3.19.0
|
||||
marshmallow-enum==1.5.1
|
||||
mpmath==1.3.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.14
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.0
|
||||
npx
|
||||
anthropic==0.12.0
|
||||
boto3==1.34.6
|
||||
celery==5.3.6
|
||||
dataclasses_json==0.6.3
|
||||
docx2txt==0.8
|
||||
EbookLib==0.18
|
||||
elasticsearch==8.12.0
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
faiss-cpu==1.7.4
|
||||
Flask==3.0.1
|
||||
gunicorn==21.2.0
|
||||
html2text==2020.1.16
|
||||
javalang==0.13.0
|
||||
langchain==0.1.4
|
||||
langchain-openai==0.0.5
|
||||
nltk==3.8.1
|
||||
numcodecs==0.11.0
|
||||
numpy==1.24.2
|
||||
openai==0.27.8
|
||||
packaging==23.0
|
||||
pathos==0.3.0
|
||||
Pillow==10.0.1
|
||||
pox==0.3.2
|
||||
ppft==1.7.6.6
|
||||
prompt-toolkit==3.0.38
|
||||
py==1.11.0
|
||||
pyasn1==0.4.8
|
||||
pycares==4.3.0
|
||||
pycparser==2.21
|
||||
pycryptodomex==3.17
|
||||
pycryptodome==3.19.0
|
||||
pydantic==1.10.5
|
||||
PyJWT==2.6.0
|
||||
pymongo==4.3.3
|
||||
pyowm==3.3.0
|
||||
openapi3_parser==1.1.16
|
||||
pandas==2.2.0
|
||||
pydantic_settings==2.1.0
|
||||
pymongo==4.6.1
|
||||
PyPDF2==3.0.1
|
||||
PySocks==1.7.1
|
||||
pytest
|
||||
python-dateutil==2.8.2
|
||||
python-dotenv==1.0.0
|
||||
python-jose==3.3.0
|
||||
pytz==2022.7.1
|
||||
PyYAML==6.0
|
||||
redis==4.5.4
|
||||
regex==2022.10.31
|
||||
requests==2.31.0
|
||||
python-dotenv==1.0.1
|
||||
redis==5.0.1
|
||||
Requests==2.31.0
|
||||
retry==0.9.2
|
||||
rsa==4.9
|
||||
scikit-learn==1.2.2
|
||||
scipy==1.10.1
|
||||
sentencepiece
|
||||
six==1.16.0
|
||||
SQLAlchemy==1.4.46
|
||||
sympy==1.11.1
|
||||
tenacity==8.2.2
|
||||
threadpoolctl==3.1.0
|
||||
tiktoken
|
||||
tqdm==4.65.0
|
||||
transformers==4.30.0
|
||||
typer==0.7.0
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.5.0
|
||||
urllib3==1.26.17
|
||||
vine==5.0.0
|
||||
wcwidth==0.2.6
|
||||
yarl==1.8.2
|
||||
sentence-transformers
|
||||
tiktoken==0.5.2
|
||||
torch==2.1.2
|
||||
tqdm==4.66.1
|
||||
transformers==4.36.2
|
||||
unstructured==0.12.2
|
||||
Werkzeug==3.0.1
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from langchain.embeddings import (
|
||||
OpenAIEmbeddings,
|
||||
from langchain_community.embeddings import (
|
||||
HuggingFaceEmbeddings,
|
||||
CohereEmbeddings,
|
||||
HuggingFaceInstructEmbeddings,
|
||||
)
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from application.core.settings import settings
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
@@ -44,6 +44,11 @@ class BaseVectorStore(ABC):
|
||||
embedding_instance = embeddings_factory[embeddings_name](
|
||||
cohere_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
embedding_instance = embeddings_factory[embeddings_name](
|
||||
#model_name="./model/all-mpnet-base-v2",
|
||||
model_kwargs={"device": "cpu"},
|
||||
)
|
||||
else:
|
||||
embedding_instance = embeddings_factory[embeddings_name]()
|
||||
|
||||
|
||||
8
application/vectorstore/document_class.py
Normal file
8
application/vectorstore/document_class.py
Normal file
@@ -0,0 +1,8 @@
|
||||
class Document(str):
    """Class for storing a piece of text and associated metadata.

    The instance itself is the text (it subclasses ``str``); the same text is
    also exposed as ``page_content`` and the metadata mapping as ``metadata``.
    """

    def __new__(cls, page_content: str, metadata: dict):
        """Create the string instance and attach the text and metadata to it.

        Args:
            page_content: The document text; becomes the string value.
            metadata: Arbitrary metadata stored on the instance as-is.
        """
        instance = super().__new__(cls, page_content)
        instance.page_content = page_content
        instance.metadata = metadata
        return instance

    def __getnewargs__(self):
        """Return the arguments ``__new__`` needs when the object is rebuilt.

        ``str.__getnewargs__`` only supplies the text, so without this override
        ``copy.copy``, ``copy.deepcopy`` and ``pickle`` would call ``__new__``
        without ``metadata`` and fail with ``TypeError``.
        """
        return (self.page_content, self.metadata)
|
||||
@@ -1,16 +1,8 @@
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.document_class import Document
|
||||
import elasticsearch
|
||||
|
||||
class Document(str):
    """Class for storing a piece of text and associated metadata.

    The instance itself is the text (it subclasses ``str``); the same text is
    also exposed as ``page_content`` and the metadata mapping as ``metadata``.
    """

    def __new__(cls, page_content: str, metadata: dict):
        """Create the string instance and attach the text and metadata to it.

        Args:
            page_content: The document text; becomes the string value.
            metadata: Arbitrary metadata stored on the instance as-is.
        """
        instance = super().__new__(cls, page_content)
        instance.page_content = page_content
        instance.metadata = metadata
        return instance

    def __getnewargs__(self):
        """Return the arguments ``__new__`` needs when the object is rebuilt.

        ``str.__getnewargs__`` only supplies the text, so without this override
        ``copy.copy``, ``copy.deepcopy`` and ``pickle`` would call ``__new__``
        without ``metadata`` and fail with ``TypeError``.
        """
        return (self.page_content, self.metadata)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from langchain.vectorstores import FAISS
|
||||
from application.core.settings import settings
|
||||
|
||||
class FaissStore(BaseVectorStore):
    """Vector store that delegates to a LangChain ``FAISS`` index.

    The index is held in ``self.docsearch``; the remaining methods forward
    their arguments to it unchanged.
    """

    def __init__(self, path, embeddings_key, docs_init=None):
        """Build a new index from ``docs_init`` or load an existing one.

        Args:
            path: Location handed to ``FAISS.load_local`` when no initial
                documents are given; also kept as ``self.path``.
            embeddings_key: Key passed to ``_get_embeddings`` together with
                ``settings.EMBEDDINGS_NAME``.
            docs_init: Optional documents; when truthy the index is created
                from them instead of being loaded from ``path``.
        """
        super().__init__()
        self.path = path
        embedder = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
        # One embeddings object is shared by index creation/loading and by the
        # dimension check below.
        self.docsearch = (
            FAISS.from_documents(docs_init, embedder)
            if docs_init
            else FAISS.load_local(self.path, embedder)
        )
        self.assert_embedding_dimensions(embedder)

    def search(self, *args, **kwargs):
        """Forward to ``self.docsearch.similarity_search``."""
        return self.docsearch.similarity_search(*args, **kwargs)

    def add_texts(self, *args, **kwargs):
        """Forward to ``self.docsearch.add_texts``."""
        return self.docsearch.add_texts(*args, **kwargs)

    def save_local(self, *args, **kwargs):
        """Forward to ``self.docsearch.save_local``."""
        return self.docsearch.save_local(*args, **kwargs)

    def delete_index(self, *args, **kwargs):
        """Forward to ``self.docsearch.delete``."""
        return self.docsearch.delete(*args, **kwargs)

    def assert_embedding_dimensions(self, embeddings):
        """
        Check that the word embedding dimension of the docsearch index matches
        the dimension of the word embeddings used.

        The check only runs for the
        "huggingface_sentence-transformers/all-mpnet-base-v2" embeddings.

        Raises:
            AttributeError: ``embeddings.client[1]`` has no
                ``word_embedding_dimension`` attribute.
            ValueError: the two dimensions differ.
        """
        if settings.EMBEDDINGS_NAME != "huggingface_sentence-transformers/all-mpnet-base-v2":
            return

        try:
            model_dim = embeddings.client[1].word_embedding_dimension
        except AttributeError as e:
            raise AttributeError("word_embedding_dimension not found in embeddings.client[1]") from e

        index_dim = self.docsearch.index.d
        if model_dim != index_dim:
            raise ValueError(f"word_embedding_dimension ({model_dim}) " +
                             f"!= docsearch_index_word_embedding_dimension ({index_dim})")
|
||||
126
application/vectorstore/mongodb.py
Normal file
126
application/vectorstore/mongodb.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.document_class import Document
|
||||
|
||||
class MongoDBVectorStore(BaseVectorStore):
    """Vector store that keeps texts and their embeddings in a MongoDB collection.

    Searches run a ``$vectorSearch`` aggregation stage, filtered on the
    ``store`` field so that one collection can hold several stores.
    """

    def __init__(
        self,
        path: str = "",
        embeddings_key: str = "embeddings",
        collection: str = "documents",
        index_name: str = "vector_search_index",
        text_key: str = "text",
        embedding_key: str = "embedding",
        database: str = "docsgpt",
    ):
        """Connect to MongoDB and prepare the embeddings object.

        Args:
            path: Store identifier. A leading "application/indexes/" part is
                removed and trailing "/" characters are stripped; the result is
                the value matched against the ``store`` field.
            embeddings_key: Key passed to ``_get_embeddings`` together with
                ``settings.EMBEDDINGS_NAME``.
            collection: Name of the collection holding the documents.
            index_name: Name of the vector search index used by ``search``.
            text_key: Document field that stores the text.
            embedding_key: Document field that stores the embedding vector.
            database: Name of the MongoDB database.

        Raises:
            ImportError: ``pymongo`` is not installed.
        """
        # FaissStore initialises the base class the same way.
        super().__init__()
        self._index_name = index_name
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._embeddings_key = embeddings_key
        self._mongo_uri = settings.MONGO_URI
        self._path = path.replace("application/indexes/", "").rstrip("/")
        self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)

        try:
            import pymongo
        except ImportError as exc:
            raise ImportError(
                "Could not import pymongo python package. "
                "Please install it with `pip install pymongo`."
            ) from exc

        self._client = pymongo.MongoClient(self._mongo_uri)
        self._database = self._client[database]
        self._collection = self._database[collection]

    def search(self, question, k=2, *args, **kwargs):
        """Return up to ``k`` documents of this store closest to ``question``.

        Args:
            question: Query text; embedded with ``embed_query``.
            k: Result limit. ``numCandidates`` is set to ``k * 10``.
            *args, **kwargs: Accepted but not used.

        Returns:
            list[Document]: One ``Document`` per hit. Its metadata is the
            stored document without the ``_id``, text and embedding fields.
        """
        query_vector = self._embedding.embed_query(question)

        pipeline = [
            {
                "$vectorSearch": {
                    "queryVector": query_vector,
                    "path": self._embedding_key,
                    "limit": k,
                    "numCandidates": k * 10,
                    "index": self._index_name,
                    "filter": {
                        "store": {"$eq": self._path}
                    }
                }
            }
        ]

        cursor = self._collection.aggregate(pipeline)

        results = []
        for doc in cursor:
            text = doc[self._text_key]
            # Whatever remains after removing these fields is the metadata.
            doc.pop("_id")
            doc.pop(self._text_key)
            doc.pop(self._embedding_key)
            metadata = doc
            results.append(Document(text, metadata))
        return results

    def _insert_texts(self, texts, metadatas):
        """Embed one batch of texts and insert it into the collection.

        Args:
            texts: Texts of the batch.
            metadatas: One metadata mapping per text; its items are merged
                into the inserted document.

        Returns:
            The ids reported by ``insert_many``, or ``[]`` for an empty batch.
        """
        if not texts:
            return []
        embeddings = self._embedding.embed_documents(texts)
        to_insert = [
            {self._text_key: t, self._embedding_key: embedding, **m}
            for t, m, embedding in zip(texts, metadatas, embeddings)
        ]
        # insert the documents in MongoDB Atlas
        insert_result = self._collection.insert_many(to_insert)
        return insert_result.inserted_ids

    def add_texts(
        self,
        texts,
        metadatas=None,
        ids=None,
        refresh_indices=True,
        create_index_if_not_exists=True,
        bulk_kwargs=None,
        **kwargs,
    ):
        """Embed and insert ``texts`` in batches of 100.

        Args:
            texts: Iterable of texts; a one-shot iterator is supported.
            metadatas: Optional iterable of metadata mappings paired with
                ``texts``. When missing or empty, every text gets ``{}``.
            ids, refresh_indices, create_index_if_not_exists, bulk_kwargs,
                **kwargs: Accepted but currently unused.

        Returns:
            list: Ids of all inserted documents, in insertion order.
        """
        #dims = self._embedding.client[1].word_embedding_dimension
        # # check if index exists
        # if create_index_if_not_exists:
        #     # check if index exists
        #     info = self._collection.index_information()
        #     if self._index_name not in info:
        #         index_mongo = {
        #             "fields": [{
        #                 "type": "vector",
        #                 "path": self._embedding_key,
        #                 "numDimensions": dims,
        #                 "similarity": "cosine",
        #             },
        #             {
        #                 "type": "filter",
        #                 "path": "store"
        #             }]
        #         }
        #         self._collection.create_index(self._index_name, index_mongo)

        batch_size = 100

        # Build the default metadata from the same pass over `texts`. A second,
        # interleaved iteration of `texts` would skip every other item when
        # `texts` is a one-shot iterator.
        if metadatas:
            pairs = zip(texts, metadatas)
        else:
            pairs = ((text, {}) for text in texts)

        texts_batch = []
        metadatas_batch = []
        result_ids = []
        for text, metadata in pairs:
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            if len(texts_batch) >= batch_size:
                result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
                texts_batch = []
                metadatas_batch = []
        # Flush the final, partially filled batch.
        if texts_batch:
            result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
        return result_ids

    def delete_index(self, *args, **kwargs):
        """Delete every document whose ``store`` field equals this store's path."""
        self._collection.delete_many({"store": self._path})
|
||||
@@ -1,11 +1,13 @@
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
from application.vectorstore.mongodb import MongoDBVectorStore
|
||||
|
||||
|
||||
class VectorCreator:
|
||||
vectorstores = {
|
||||
'faiss': FaissStore,
|
||||
'elasticsearch':ElasticsearchStore
|
||||
'elasticsearch':ElasticsearchStore,
|
||||
'mongodb': MongoDBVectorStore,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -21,17 +21,34 @@ except FileExistsError:
|
||||
pass
|
||||
|
||||
|
||||
# Define a function to extract metadata from a given filename.
|
||||
def metadata_from_filename(title):
    """Build the metadata dict for a document from its path-like title.

    ``store`` is the second and third ``/``-separated components of ``title``
    joined by ``/`` (fewer, or empty, when the title has fewer components).
    """
    segments = title.split('/')
    return {'title': title, 'store': '/'.join(segments[1:3])}
|
||||
|
||||
|
||||
# Define a function to generate a string of a given length by cycling through the ASCII letters (deterministic, not actually random).
|
||||
def generate_random_string(length):
    """Return ``length`` characters taken cyclically from ``string.ascii_letters``.

    Despite the name, the result is deterministic: the same ``length`` always
    gives the same string. A non-positive ``length`` gives ``''``.
    """
    alphabet = string.ascii_letters
    # Repeat the alphabet often enough to cover `length`, then cut to size.
    repeats = length // len(alphabet) + 1
    return (alphabet * repeats)[:length]
|
||||
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Define the main function for ingesting and processing documents.
|
||||
def ingest_worker(self, directory, formats, name_job, filename, user):
|
||||
"""
|
||||
Ingest and process documents.
|
||||
|
||||
Args:
|
||||
self: Reference to the instance of the task.
|
||||
directory (str): Specifies the directory for ingesting ('inputs' or 'temp').
|
||||
formats (list of str): List of file extensions to consider for ingestion (e.g., [".rst", ".md"]).
|
||||
name_job (str): Name of the job for this ingestion task.
|
||||
filename (str): Name of the file to be ingested.
|
||||
user (str): Identifier for the user initiating the ingestion.
|
||||
|
||||
Returns:
|
||||
dict: Information about the completed ingestion task, including input parameters and a "limited" flag.
|
||||
"""
|
||||
# directory = 'inputs' or 'temp'
|
||||
# formats = [".rst", ".md"]
|
||||
input_files = None
|
||||
|
||||
Reference in New Issue
Block a user