diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 35b95174..9a22db84 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -40,6 +40,8 @@ if settings.LLM_NAME == "openai": gpt_model = "gpt-3.5-turbo" elif settings.LLM_NAME == "anthropic": gpt_model = "claude-2" +elif settings.LLM_NAME == "groq": + gpt_model = "llama3-8b-8192" if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME diff --git a/application/llm/groq.py b/application/llm/groq.py new file mode 100644 index 00000000..b5731a90 --- /dev/null +++ b/application/llm/groq.py @@ -0,0 +1,45 @@ +from application.llm.base import BaseLLM + + + +class GroqLLM(BaseLLM): + + def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): + from openai import OpenAI + + super().__init__(*args, **kwargs) + self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1") + self.api_key = api_key + self.user_api_key = user_api_key + + def _raw_gen( + self, + baseself, + model, + messages, + stream=False, + **kwargs + ): + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + + return response.choices[0].message.content + + def _raw_gen_stream( + self, + baseself, + model, + messages, + stream=True, + **kwargs + ): + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + + for line in response: + # import sys + # print(line.choices[0].delta.content, file=sys.stderr) + if line.choices[0].delta.content is not None: + yield line.choices[0].delta.content diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py index 7960778b..6a19de10 100644 --- a/application/llm/llm_creator.py +++ b/application/llm/llm_creator.py @@ -1,3 +1,4 @@ +from application.llm.groq import GroqLLM from application.llm.openai import OpenAILLM, AzureOpenAILLM from 
application.llm.sagemaker import SagemakerAPILLM from application.llm.huggingface import HuggingFaceLLM @@ -17,6 +18,7 @@ class LLMCreator: "anthropic": AnthropicLLM, "docsgpt": DocsGPTAPILLM, "premai": PremAILLM, + "groq": GroqLLM, } @classmethod
+
{message}