mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'main' of https://github.com/arc53/DocsGPT
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from langchain_community.tools import BraveSearch
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class BraveRetSearch(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
question,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
@@ -19,7 +20,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.question = ""
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
@@ -93,7 +94,9 @@ class BraveRetSearch(BaseRetriever):
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
def search(self, query: str = ""):
|
||||
if query:
|
||||
self.question = query
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
question,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
@@ -19,7 +19,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.question = ""
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
@@ -38,41 +38,24 @@ class DuckDuckSearch(BaseRetriever):
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _parse_lang_string(self, input_string):
|
||||
result = []
|
||||
current_item = ""
|
||||
inside_brackets = False
|
||||
for char in input_string:
|
||||
if char == "[":
|
||||
inside_brackets = True
|
||||
elif char == "]":
|
||||
inside_brackets = False
|
||||
result.append(current_item)
|
||||
current_item = ""
|
||||
elif inside_brackets:
|
||||
current_item += char
|
||||
|
||||
if inside_brackets:
|
||||
result.append(current_item)
|
||||
|
||||
return result
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper, output_format="list")
|
||||
results = search.run(self.question)
|
||||
results = self._parse_lang_string(results)
|
||||
|
||||
docs = []
|
||||
for i in results:
|
||||
try:
|
||||
text = i.split("title:")[0]
|
||||
title = i.split("title:")[1].split("link:")[0]
|
||||
link = i.split("link:")[1]
|
||||
docs.append({"text": text, "title": title, "link": link})
|
||||
docs.append(
|
||||
{
|
||||
"text": i.get("snippet", "").strip(),
|
||||
"title": i.get("title", "").strip(),
|
||||
"link": i.get("link", "").strip(),
|
||||
}
|
||||
)
|
||||
except IndexError:
|
||||
pass
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
@@ -110,7 +93,9 @@ class DuckDuckSearch(BaseRetriever):
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
def search(self, query: str = ""):
|
||||
if query:
|
||||
self.question = query
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
|
||||
@@ -471,6 +471,7 @@ function Upload({
|
||||
}, 3000);
|
||||
};
|
||||
xhr.open('POST', `${apiHost}/api/remote`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
};
|
||||
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""
|
||||
Tests regarding the vector store class, including checking
|
||||
compatibility between different transformers and local vector
|
||||
stores (index.faiss)
|
||||
"""
|
||||
import pytest
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
from application.core.settings import settings
|
||||
|
||||
def test_init_local_faiss_store_huggingface():
|
||||
"""
|
||||
Test that asserts that initializing a FaissStore with
|
||||
the huggingface sentence transformer below together with the
|
||||
index.faiss file in the application/ folder results in a
|
||||
dimension mismatch error.
|
||||
"""
|
||||
import os
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.docstore.document import Document
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
# Ensure application directory exists
|
||||
index_path = os.path.join("application")
|
||||
os.makedirs(index_path, exist_ok=True)
|
||||
|
||||
# Create an index.faiss with a different embeddings dimension
|
||||
# Use a different embedding model with a smaller dimension
|
||||
other_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" # Dimension 384
|
||||
other_embeddings = HuggingFaceEmbeddings(model_name=other_embedding_model)
|
||||
# Create some dummy documents
|
||||
docs = [Document(page_content="Test document")]
|
||||
# Create index using the other embeddings
|
||||
other_docsearch = FAISS.from_documents(docs, other_embeddings)
|
||||
# Save index to application/
|
||||
other_docsearch.save_local(index_path)
|
||||
|
||||
# Now set the EMBEDDINGS_NAME to the one with a different dimension
|
||||
settings.EMBEDDINGS_NAME = "huggingface_sentence-transformers/all-mpnet-base-v2" # Dimension 768
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
FaissStore("", None)
|
||||
assert "Embedding dimension mismatch" in str(exc_info.value)
|
||||
Reference in New Issue
Block a user