From ce56a414e039c68d0ee32f8ae591ebede05e21b5 Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 25 Jun 2024 14:37:00 +0100
Subject: [PATCH 1/2] fix: use singleton for llama.cpp model loading
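
Each LlamaCpp instantiation imported llama_cpp and loaded the model into a
module-level "global llama", so every new wrapper reloaded the weights and
_raw_gen always used whichever model had been constructed last. Cache one
Llama object per model path in LlamaSingleton and keep a reference on
self.llama instead; stale commented-out debug prints are dropped along the
way.

A minimal sketch of the effect (hypothetical model path):

    llm_a = LlamaCpp(llm_name="/models/example.gguf")
    llm_b = LlamaCpp(llm_name="/models/example.gguf")
    assert llm_a.llama is llm_b.llama  # weights loaded once, then shared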
---
 application/llm/llama_cpp.py | 41 ++++++++++++++++--------------------
 1 file changed, 18 insertions(+), 23 deletions(-)

diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py
index 25a2f0c9..a93fa686 100644
--- a/application/llm/llama_cpp.py
+++ b/application/llm/llama_cpp.py
@@ -1,9 +1,22 @@
 from application.llm.base import BaseLLM
 from application.core.settings import settings
 
 
+class LlamaSingleton:
+    _instances = {}
+
+    @classmethod
+    def get_instance(cls, llm_name):
+        if llm_name not in cls._instances:
+            try:
+                from llama_cpp import Llama
+            except ImportError:
+                raise ImportError(
+                    "Please install llama_cpp using pip install llama-cpp-python"
+                )
+            cls._instances[llm_name] = Llama(model_path=llm_name, n_ctx=2048)
+        return cls._instances[llm_name]
 class LlamaCpp(BaseLLM):
-
     def __init__(
         self,
         api_key=None,
@@ -12,41 +25,23 @@ class LlamaCpp(BaseLLM):
         *args,
         **kwargs,
     ):
-        global llama
-        try:
-            from llama_cpp import Llama
-        except ImportError:
-            raise ImportError(
-                "Please install llama_cpp using pip install llama-cpp-python"
-            )
-
         super().__init__(*args, **kwargs)
         self.api_key = api_key
         self.user_api_key = user_api_key
-        llama = Llama(model_path=llm_name, n_ctx=2048)
+        self.llama = LlamaSingleton.get_instance(llm_name)
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-
-        result = llama(prompt, max_tokens=150, echo=False)
-
-        # import sys
-        # print(result['choices'][0]['text'].split('### Answer \n')[-1], file=sys.stderr)
-
+        result = self.llama(prompt, max_tokens=150, echo=False)
         return result["choices"][0]["text"].split("### Answer \n")[-1]
 
     def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-
-        result = llama(prompt, max_tokens=150, echo=False, stream=stream)
-
-        # import sys
-        # print(list(result), file=sys.stderr)
-
+        result = self.llama(prompt, max_tokens=150, echo=False, stream=stream)
         for item in result:
             for choice in item["choices"]:
-                yield choice["text"]
+                yield choice["text"]
\ No newline at end of file
From 5aa88714b86ca2dc2b73416db578469b84e45a42 Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 25 Jun 2024 14:41:04 +0100
Subject: [PATCH 2/2] refactor: add thread lock around llama.cpp calls
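
A llama.cpp context is not safe to call concurrently from multiple threads.
Add a class-level threading.Lock to LlamaSingleton and route generation
through query_model(), which holds the lock for the duration of the call.
Note that with stream=True the call returns a generator, so the lock is
released once the generator is created rather than after all tokens are
drawn.

A rough sketch of serialized access from worker threads (the model path and
prompt are illustrative):

    import threading

    def worker(llm, prompt):
        # Each call takes LlamaSingleton._lock, so inference is serialized.
        return LlamaSingleton.query_model(llm, prompt, max_tokens=32)

    llm = LlamaSingleton.get_instance("/models/example.gguf")
    threads = [threading.Thread(target=worker, args=(llm, "hi")) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()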
---
 application/llm/llama_cpp.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py
index a93fa686..804c3c56 100644
--- a/application/llm/llama_cpp.py
+++ b/application/llm/llama_cpp.py
@@ -1,8 +1,10 @@
 from application.llm.base import BaseLLM
 from application.core.settings import settings
+import threading
 
 
 class LlamaSingleton:
     _instances = {}
+    _lock = threading.Lock()  # Add a lock for thread synchronization
 
     @classmethod
     def get_instance(cls, llm_name):
@@ -16,6 +18,12 @@ class LlamaSingleton:
                 )
             cls._instances[llm_name] = Llama(model_path=llm_name, n_ctx=2048)
         return cls._instances[llm_name]
+    @classmethod
+    def query_model(cls, llm, prompt, **kwargs):
+        with cls._lock:
+            return llm(prompt, **kwargs)
+
+
 class LlamaCpp(BaseLLM):
     def __init__(
         self,
@@ -34,14 +42,14 @@ class LlamaCpp(BaseLLM):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        result = self.llama(prompt, max_tokens=150, echo=False)
+        result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
         return result["choices"][0]["text"].split("### Answer \n")[-1]
 
     def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        result = self.llama(prompt, max_tokens=150, echo=False, stream=stream)
+        result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
         for item in result:
             for choice in item["choices"]:
                 yield choice["text"]
\ No newline at end of file