sagemaker + llm creator class

This commit is contained in:
Alex
2023-09-29 01:09:01 +01:00
parent c1c54f4848
commit e4be38b9f7
5 changed files with 57 additions and 25 deletions

View File

@@ -0,0 +1,20 @@
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM
class LLMCreator:
llms = {
'openai': OpenAILLM,
'azure_openai': AzureOpenAILLM,
'sagemaker': SagemakerAPILLM,
'huggingface': HuggingFaceLLM
}
@classmethod
def create_llm(cls, type, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
return llm_class(*args, **kwargs)

View File

@@ -1,4 +1,5 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
class OpenAILLM(BaseLLM):
@@ -44,9 +45,9 @@ class AzureOpenAILLM(OpenAILLM):
def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name):
super().__init__(openai_api_key)
self.api_base = openai_api_base
self.api_version = openai_api_version
self.deployment_name = deployment_name
self.api_base = settings.OPENAI_API_BASE,
self.api_version = settings.OPENAI_API_VERSION,
self.deployment_name = settings.AZURE_DEPLOYMENT_NAME,
def _get_openai(self):
openai = super()._get_openai()

View File

@@ -0,0 +1,27 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
import requests
import json
class SagemakerAPILLM(BaseLLM):
def __init__(self, *args, **kwargs):
self.url = settings.SAGEMAKER_API_URL
def gen(self, model, engine, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
response = requests.post(
url=self.url,
headers={
"Content-Type": "application/json; charset=utf-8",
},
data=json.dumps({"input": prompt})
)
return response.json()['answer']
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
raise NotImplementedError("Sagemaker does not support streaming")