fix: api_key capturing + pytest errors

Make api_key optional (defaulting to None) in every LLM provider's __init__
so providers can be constructed without credentials, and add the extra
baseself parameter that _raw_gen and _raw_gen_stream now receive.
@@ -15,7 +15,9 @@ class AnthropicLLM(BaseLLM):
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT

-    def _raw_gen(self, model, messages, max_tokens=300, stream=False, **kwargs):
+    def _raw_gen(
+        self, baseself, model, messages, max_tokens=300, stream=False, **kwargs
+    ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Context \n {context} \n ### Question \n {user_question}"
@@ -30,7 +32,7 @@ class AnthropicLLM(BaseLLM):
         )
         return completion.completion

-    def _raw_gen_stream(self, model, messages, max_tokens=300, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, max_tokens=300, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Context \n {context} \n ### Question \n {user_question}"
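The extra baseself parameter looks odd on an already-bound method. It fits a calling convention in which BaseLLM routes gen/gen_stream through wrapper decorators and passes the instance in again as an explicit argument; the sketch below is a minimal illustration of that convention, not DocsGPT's actual base class (gen_cache and EchoLLM are made-up names):

def gen_cache(func):
    # Stand-in for a caching/usage decorator: it receives the instance
    # explicitly, so it can read attributes such as api_key.
    def wrapper(instance, *args, **kwargs):
        return func(instance, *args, **kwargs)
    return wrapper

class BaseLLM:
    def gen(self, model, messages, **kwargs):
        # self._raw_gen is already bound, so passing self again delivers it
        # to the implementation as the extra baseself argument.
        return gen_cache(self._raw_gen)(self, model, messages, **kwargs)

class EchoLLM(BaseLLM):
    def _raw_gen(self, baseself, model, messages, **kwargs):
        return messages[-1]["content"]

print(EchoLLM().gen("echo", [{"content": "hi"}]))  # prints "hi"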
@@ -5,7 +5,7 @@ import requests

 class DocsGPTAPILLM(BaseLLM):

-    def __init__(self, api_key, *args, **kwargs):
+    def __init__(self, api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.api_key = api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
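Defaulting api_key to None is the pytest half of the fix: the suite constructs providers without credentials, and a required positional api_key made that raise TypeError. A hypothetical regression test to that effect (the import path is an assumption, not shown in this diff):

# Hypothetical test; the module path is assumed.
from application.llm.docsgpt_provider import DocsGPTAPILLM

def test_llm_constructs_without_api_key():
    llm = DocsGPTAPILLM()  # no TypeError now that api_key defaults to None
    assert llm.api_key is None
    assert llm.endpoint == "https://llm.docsgpt.co.uk"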
@@ -3,7 +3,9 @@ from application.llm.base import BaseLLM

 class HuggingFaceLLM(BaseLLM):

-    def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs):
+    def __init__(
+        self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
+    ):
         global hf

         from langchain.llms import HuggingFacePipeline
@@ -45,7 +47,7 @@ class HuggingFaceLLM(BaseLLM):
         )
         hf = HuggingFacePipeline(pipeline=pipe)

-    def _raw_gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -54,6 +56,6 @@ class HuggingFaceLLM(BaseLLM):

         return result.content

-    def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):

         raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
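All of these providers build the prompt the same way: messages[0] supplies the retrieved context and messages[-1] the user's question. A worked example with made-up message contents:

messages = [
    {"content": "DocsGPT supports several LLM backends."},  # retrieved context
    {"content": "Which backends are supported?"},           # user question
]
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
print(prompt)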
@@ -4,7 +4,7 @@ from application.core.settings import settings

 class LlamaCpp(BaseLLM):

-    def __init__(self, api_key, llm_name=settings.MODEL_PATH, *args, **kwargs):
+    def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs):
         global llama
         try:
             from llama_cpp import Llama
@@ -17,7 +17,7 @@ class LlamaCpp(BaseLLM):
         self.api_key = api_key
         llama = Llama(model_path=llm_name, n_ctx=2048)

-    def _raw_gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -29,7 +29,7 @@ class LlamaCpp(BaseLLM):

         return result["choices"][0]["text"].split("### Answer \n")[-1]

-    def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
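The split("### Answer \n")[-1] extraction keeps only what follows the last answer marker, in case the backend echoes the prompt; if the marker is absent, split()[-1] returns the text unchanged. A toy illustration:

raw_text = "### Instruction \n hi \n ### Context \n docs \n ### Answer \n Hello!"
answer = raw_text.split("### Answer \n")[-1]
print(answer)  # " Hello!"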
@@ -4,7 +4,7 @@ from application.core.settings import settings

 class OpenAILLM(BaseLLM):

-    def __init__(self, api_key, *args, **kwargs):
+    def __init__(self, api_key=None, *args, **kwargs):
         global openai
         from openai import OpenAI
@@ -22,6 +22,7 @@ class OpenAILLM(BaseLLM):

     def _raw_gen(
         self,
+        baseself,
         model,
         messages,
         stream=False,
@@ -36,6 +37,7 @@ class OpenAILLM(BaseLLM):

     def _raw_gen_stream(
         self,
+        baseself,
         model,
         messages,
         stream=True,
@@ -4,7 +4,7 @@ from application.core.settings import settings

 class PremAILLM(BaseLLM):

-    def __init__(self, api_key, *args, **kwargs):
+    def __init__(self, api_key=None, *args, **kwargs):
         from premai import Prem

         super().__init__(*args, **kwargs)
@@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
         self.api_key = api_key
         self.project_id = settings.PREMAI_PROJECT_ID

-    def _raw_gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         response = self.client.chat.completions.create(
             model=model,
             project_id=self.project_id,
@@ -23,7 +23,7 @@ class PremAILLM(BaseLLM):

         return response.choices[0].message["content"]

-    def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         response = self.client.chat.completions.create(
             model=model,
             project_id=self.project_id,
@@ -60,7 +60,7 @@ class LineIterator:

 class SagemakerAPILLM(BaseLLM):

-    def __init__(self, api_key, *args, **kwargs):
+    def __init__(self, api_key=None, *args, **kwargs):
         import boto3

         runtime = boto3.client(
@@ -75,7 +75,7 @@ class SagemakerAPILLM(BaseLLM):
         self.endpoint = settings.SAGEMAKER_ENDPOINT
         self.runtime = runtime

-    def _raw_gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -104,7 +104,7 @@ class SagemakerAPILLM(BaseLLM):

         print(result[0]["generated_text"], file=sys.stderr)
         return result[0]["generated_text"][len(prompt) :]

-    def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
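The [len(prompt):] slice implies the SageMaker endpoint returns generated_text with the prompt echoed at the front, so the prefix is cut off before returning. A toy illustration of the same idea:

prompt = "### Instruction \n hi \n ### Answer \n"
generated_text = prompt + "Hello!"      # endpoint assumed to echo the prompt
print(generated_text[len(prompt):])     # -> "Hello!"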