sagemaker + llm creator class

This commit is contained in:
Alex
2023-09-29 01:09:01 +01:00
parent c1c54f4848
commit e4be38b9f7
5 changed files with 57 additions and 25 deletions

View File

@@ -13,7 +13,7 @@ from transformers import GPT2TokenizerFast
from application.core.settings import settings
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.llm_creator import LLMCreator
from application.vectorstore.faiss import FaissStore
from application.error import bad_request
@@ -128,16 +128,8 @@ def is_azure_configured():
def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
if is_azure_configured():
llm = AzureOpenAILLM(
openai_api_key=api_key,
openai_api_base=settings.OPENAI_API_BASE,
openai_api_version=settings.OPENAI_API_VERSION,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)
else:
logger.debug("plain OpenAI")
llm = OpenAILLM(api_key=api_key)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
docs = docsearch.search(question, k=2)
# join all page_content together with a newline
@@ -270,16 +262,8 @@ def api_answer():
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = FaissStore(vectorstore, embeddings_key)
if is_azure_configured():
llm = AzureOpenAILLM(
openai_api_key=api_key,
openai_api_base=settings.OPENAI_API_BASE,
openai_api_version=settings.OPENAI_API_VERSION,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)
else:
logger.debug("plain OpenAI")
llm = OpenAILLM(api_key=api_key)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)