mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
fix: brave and duckduckgo retrievers
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from langchain_community.tools import BraveSearch
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class BraveRetSearch(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
question,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
@@ -19,7 +20,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.question = ""
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
@@ -93,7 +94,9 @@ class BraveRetSearch(BaseRetriever):
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
def search(self, query: str = ""):
|
||||
if query:
|
||||
self.question = query
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
question,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
@@ -19,7 +19,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.question = ""
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
@@ -38,41 +38,24 @@ class DuckDuckSearch(BaseRetriever):
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _parse_lang_string(self, input_string):
|
||||
result = []
|
||||
current_item = ""
|
||||
inside_brackets = False
|
||||
for char in input_string:
|
||||
if char == "[":
|
||||
inside_brackets = True
|
||||
elif char == "]":
|
||||
inside_brackets = False
|
||||
result.append(current_item)
|
||||
current_item = ""
|
||||
elif inside_brackets:
|
||||
current_item += char
|
||||
|
||||
if inside_brackets:
|
||||
result.append(current_item)
|
||||
|
||||
return result
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper, output_format="list")
|
||||
results = search.run(self.question)
|
||||
results = self._parse_lang_string(results)
|
||||
|
||||
docs = []
|
||||
for i in results:
|
||||
try:
|
||||
text = i.split("title:")[0]
|
||||
title = i.split("title:")[1].split("link:")[0]
|
||||
link = i.split("link:")[1]
|
||||
docs.append({"text": text, "title": title, "link": link})
|
||||
docs.append(
|
||||
{
|
||||
"text": i.get("snippet", "").strip(),
|
||||
"title": i.get("title", "").strip(),
|
||||
"link": i.get("link", "").strip(),
|
||||
}
|
||||
)
|
||||
except IndexError:
|
||||
pass
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
@@ -110,7 +93,9 @@ class DuckDuckSearch(BaseRetriever):
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
def search(self, query: str = ""):
|
||||
if query:
|
||||
self.question = query
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
|
||||
@@ -471,6 +471,7 @@ function Upload({
|
||||
}, 3000);
|
||||
};
|
||||
xhr.open('POST', `${apiHost}/api/remote`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
};
|
||||
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
||||
|
||||
Reference in New Issue
Block a user