mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
fix: api_key capturing + pytest errors
This commit is contained in:
@@ -3,7 +3,9 @@ from application.llm.base import BaseLLM
|
||||
|
||||
class HuggingFaceLLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs):
|
||||
def __init__(
|
||||
self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
|
||||
):
|
||||
global hf
|
||||
|
||||
from langchain.llms import HuggingFacePipeline
|
||||
@@ -45,7 +47,7 @@ class HuggingFaceLLM(BaseLLM):
|
||||
)
|
||||
hf = HuggingFacePipeline(pipeline=pipe)
|
||||
|
||||
def _raw_gen(self, model, messages, stream=False, **kwargs):
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
@@ -54,6 +56,6 @@ class HuggingFaceLLM(BaseLLM):
|
||||
|
||||
return result.content
|
||||
|
||||
def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
||||
|
||||
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
|
||||
|
||||
Reference in New Issue
Block a user