mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
fix: api_key capturing + pytest errors
This commit is contained in:
@@ -6,16 +6,25 @@ from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
|
||||
|
||||
class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
|
||||
def __init__(
|
||||
self,
|
||||
question,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
chunks=2,
|
||||
gpt_model="docsgpt",
|
||||
api_key=None,
|
||||
):
|
||||
self.question = question
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
self.api_key = api_key
|
||||
|
||||
def _parse_lang_string(self, input_string):
|
||||
result = []
|
||||
@@ -30,12 +39,12 @@ class DuckDuckSearch(BaseRetriever):
|
||||
current_item = ""
|
||||
elif inside_brackets:
|
||||
current_item += char
|
||||
|
||||
|
||||
if inside_brackets:
|
||||
result.append(current_item)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
@@ -44,7 +53,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
|
||||
results = search.run(self.question)
|
||||
results = self._parse_lang_string(results)
|
||||
|
||||
|
||||
docs = []
|
||||
for i in results:
|
||||
try:
|
||||
@@ -56,12 +65,12 @@ class DuckDuckSearch(BaseRetriever):
|
||||
pass
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def gen(self):
|
||||
docs = self._get_data()
|
||||
|
||||
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
@@ -75,20 +84,27 @@ class DuckDuckSearch(BaseRetriever):
|
||||
self.chat_history.reverse()
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
||||
i["response"]
|
||||
)
|
||||
if (
|
||||
tokens_current_history + tokens_batch
|
||||
< settings.TOKENS_MAX_HISTORY
|
||||
):
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "system", "content": i["response"]})
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "system", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model,
|
||||
messages=messages_combine)
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
|
||||
def search(self):
|
||||
return self._get_data()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user