fix: api_key capturing + pytest errors

This commit is contained in:
Siddhant Rai
2024-04-15 22:32:24 +05:30
parent 60a670ce29
commit 77991896b4
11 changed files with 276 additions and 145 deletions

View File

@@ -3,7 +3,9 @@ from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs):
def __init__(
self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
):
global hf
from langchain.llms import HuggingFacePipeline
@@ -45,7 +47,7 @@ class HuggingFaceLLM(BaseLLM):
)
hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, model, messages, stream=False, **kwargs):
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -54,6 +56,6 @@ class HuggingFaceLLM(BaseLLM):
return result.content
def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")