fix: changes to llm classes according to base

This commit is contained in:
Siddhant Rai
2024-04-15 19:47:24 +05:30
parent c1c69ed22b
commit 60a670ce29
6 changed files with 15 additions and 6 deletions

View File

@@ -3,7 +3,7 @@ from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False):
def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs):
global hf
from langchain.llms import HuggingFacePipeline
@@ -33,6 +33,8 @@ class HuggingFaceLLM(BaseLLM):
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
super().__init__(*args, **kwargs)
self.api_key = api_key
pipe = pipeline(
"text-generation",
model=model,