From 32a019c0d6e3e1b6fe969e12367b1bbec27a7160 Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 27 Sep 2023 22:39:48 +0100
Subject: [PATCH 1/2] Update requirements.txt
---
application/requirements.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/application/requirements.txt b/application/requirements.txt
index d978cb41..57520a43 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -67,6 +67,7 @@ pyasn1==0.4.8
pycares==4.3.0
pycparser==2.21
pycryptodomex==3.17
+pycryptodome==3.19.0
pydantic==1.10.5
PyJWT==2.6.0
pymongo==4.3.3
From e4be38b9f7533d55884acf8a010000a59a7370f1 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 29 Sep 2023 01:09:01 +0100
Subject: [PATCH 2/2] sagemaker + llm creator class
---
application/api/answer/routes.py | 26 +++++---------------------
application/core/settings.py | 2 +-
application/llm/llm_creator.py | 20 ++++++++++++++++++++
application/llm/openai.py | 7 ++++---
application/llm/sagemaker.py | 27 +++++++++++++++++++++++++++
5 files changed, 57 insertions(+), 25 deletions(-)
create mode 100644 application/llm/llm_creator.py
create mode 100644 application/llm/sagemaker.py
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index ae9ef71f..86943ff1 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -13,7 +13,7 @@ from transformers import GPT2TokenizerFast
from application.core.settings import settings
-from application.llm.openai import OpenAILLM, AzureOpenAILLM
+from application.llm.llm_creator import LLMCreator
from application.vectorstore.faiss import FaissStore
from application.error import bad_request
@@ -128,16 +128,8 @@ def is_azure_configured():
def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
- if is_azure_configured():
- llm = AzureOpenAILLM(
- openai_api_key=api_key,
- openai_api_base=settings.OPENAI_API_BASE,
- openai_api_version=settings.OPENAI_API_VERSION,
- deployment_name=settings.AZURE_DEPLOYMENT_NAME,
- )
- else:
- logger.debug("plain OpenAI")
- llm = OpenAILLM(api_key=api_key)
+ llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
+
docs = docsearch.search(question, k=2)
# join all page_content together with a newline
@@ -270,16 +262,8 @@ def api_answer():
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = FaissStore(vectorstore, embeddings_key)
- if is_azure_configured():
- llm = AzureOpenAILLM(
- openai_api_key=api_key,
- openai_api_base=settings.OPENAI_API_BASE,
- openai_api_version=settings.OPENAI_API_VERSION,
- deployment_name=settings.AZURE_DEPLOYMENT_NAME,
- )
- else:
- logger.debug("plain OpenAI")
- llm = OpenAILLM(api_key=api_key)
+
+ llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
diff --git a/application/core/settings.py b/application/core/settings.py
index d127c293..1479beb3 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -4,7 +4,7 @@ from pydantic import BaseSettings
class Settings(BaseSettings):
- LLM_NAME: str = "openai_chat"
+ LLM_NAME: str = "openai"
EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
new file mode 100644
index 00000000..a7ffc0f6
--- /dev/null
+++ b/application/llm/llm_creator.py
@@ -0,0 +1,20 @@
+from application.llm.openai import OpenAILLM, AzureOpenAILLM
+from application.llm.sagemaker import SagemakerAPILLM
+from application.llm.huggingface import HuggingFaceLLM
+
+
+
+class LLMCreator:
+ llms = {
+ 'openai': OpenAILLM,
+ 'azure_openai': AzureOpenAILLM,
+ 'sagemaker': SagemakerAPILLM,
+ 'huggingface': HuggingFaceLLM
+ }
+
+ @classmethod
+ def create_llm(cls, type, *args, **kwargs):
+ llm_class = cls.llms.get(type.lower())
+ if not llm_class:
+ raise ValueError(f"No LLM class found for type {type}")
+        return llm_class(*args, **kwargs)
diff --git a/application/llm/openai.py b/application/llm/openai.py
index 23e5fab0..34d56854 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -1,4 +1,5 @@
from application.llm.base import BaseLLM
+from application.core.settings import settings
class OpenAILLM(BaseLLM):
@@ -44,9 +45,9 @@ class AzureOpenAILLM(OpenAILLM):
-    def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name):
-        super().__init__(openai_api_key)
-        self.api_base = openai_api_base
-        self.api_version = openai_api_version
-        self.deployment_name = deployment_name
+    def __init__(self, api_key, *args, **kwargs):
+        super().__init__(api_key)
+        self.api_base = settings.OPENAI_API_BASE
+        self.api_version = settings.OPENAI_API_VERSION
+        self.deployment_name = settings.AZURE_DEPLOYMENT_NAME
def _get_openai(self):
openai = super()._get_openai()
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
new file mode 100644
index 00000000..9ef5d0af
--- /dev/null
+++ b/application/llm/sagemaker.py
@@ -0,0 +1,27 @@
+from application.llm.base import BaseLLM
+from application.core.settings import settings
+import requests
+import json
+
+class SagemakerAPILLM(BaseLLM):
+
+ def __init__(self, *args, **kwargs):
+ self.url = settings.SAGEMAKER_API_URL
+
+ def gen(self, model, engine, messages, stream=False, **kwargs):
+ context = messages[0]['content']
+ user_question = messages[-1]['content']
+ prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+ response = requests.post(
+ url=self.url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ },
+ data=json.dumps({"input": prompt})
+ )
+
+ return response.json()['answer']
+
+ def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+        raise NotImplementedError("Sagemaker does not support streaming")