"""Unit tests for application/llm/llama_cpp.py — LlamaCpp and LlamaSingleton.
|
|
|
|
Covers:
|
|
- LlamaSingleton: get_instance, query_model (thread-safe)
|
|
- LlamaCpp constructor
|
|
- _raw_gen: prompt format and result extraction
|
|
- _raw_gen_stream: streaming iteration
|
|
"""
|
|
|
|
import sys
|
|
import types
|
|
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fake llama_cpp module
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
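
# FakeLlama stands in for llama_cpp.Llama: it records the kwargs of the last
# call and returns canned completion / stream payloads for the tests to assert on.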
class FakeLlama:
    def __init__(self, model_path=None, n_ctx=None):
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.last_call = None

    def __call__(self, prompt, **kwargs):
        self.last_call = {"prompt": prompt, **kwargs}
        if kwargs.get("stream"):
            return iter(
                [
                    {"choices": [{"text": "chunk1"}]},
                    {"choices": [{"text": "chunk2"}]},
                ]
            )
        return {"choices": [{"text": "prefix ### Answer \nthe answer"}]}


@pytest.fixture(autouse=True)
def patch_llama_cpp(monkeypatch):
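    """Install a fake ``llama_cpp`` module and drop any cached import of the
    module under test so it binds to the fake instead of a real model."""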
    fake_mod = types.ModuleType("llama_cpp")
    fake_mod.Llama = FakeLlama
    sys.modules["llama_cpp"] = fake_mod

    # Clear any cached instances
    if "application.llm.llama_cpp" in sys.modules:
        del sys.modules["application.llm.llama_cpp"]

    yield

    sys.modules.pop("llama_cpp", None)
    if "application.llm.llama_cpp" in sys.modules:
        del sys.modules["application.llm.llama_cpp"]


@pytest.fixture
def fresh_singleton():
    from application.llm.llama_cpp import LlamaSingleton
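
    # Reset the class-level cache so each test starts with no cached Llama objects.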
    LlamaSingleton._instances = {}
    return LlamaSingleton


@pytest.fixture
def llm(fresh_singleton):
    from application.llm.llama_cpp import LlamaCpp

    instance = LlamaCpp(api_key="k", user_api_key=None, llm_name="/path/to/model")
    return instance


# ---------------------------------------------------------------------------
# LlamaSingleton
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestLlamaSingleton:

    def test_get_instance_creates_llama(self, fresh_singleton):
        instance = fresh_singleton.get_instance("/model/path")
        assert isinstance(instance, FakeLlama)
        assert instance.model_path == "/model/path"

    def test_get_instance_caches(self, fresh_singleton):
        inst1 = fresh_singleton.get_instance("/model")
        inst2 = fresh_singleton.get_instance("/model")
        assert inst1 is inst2

    def test_different_names_different_instances(self, fresh_singleton):
        inst1 = fresh_singleton.get_instance("/model_a")
        inst2 = fresh_singleton.get_instance("/model_b")
        assert inst1 is not inst2

    def test_query_model_thread_safe(self, fresh_singleton):
        instance = fresh_singleton.get_instance("/model")
        result = fresh_singleton.query_model(instance, "prompt", max_tokens=10)
        assert "choices" in result

    def test_import_error_raised(self, fresh_singleton, monkeypatch):
        # Remove the fake module to simulate import failure
        sys.modules.pop("llama_cpp", None)
        fresh_singleton._instances = {}

        with pytest.raises(ImportError, match="llama_cpp"):
            fresh_singleton.get_instance("/new_model")


# ---------------------------------------------------------------------------
# LlamaCpp constructor
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestLlamaCppConstructor:

    def test_sets_api_key(self, llm):
        assert llm.api_key == "k"

    def test_sets_user_api_key(self):
        from application.llm.llama_cpp import LlamaCpp, LlamaSingleton

        LlamaSingleton._instances = {}
        instance = LlamaCpp(
            api_key="k", user_api_key="uk", llm_name="/path/model"
        )
        assert instance.user_api_key == "uk"

    def test_creates_llama_instance(self, llm):
        assert isinstance(llm.llama, FakeLlama)


# ---------------------------------------------------------------------------
# _raw_gen
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestRawGen:
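
    # The tests pass the instance explicitly as the first positional argument,
    # mirroring how the application's gen() wrapper presumably invokes _raw_gen.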
    def test_returns_answer(self, llm):
        msgs = [
            {"content": "context text"},
            {"content": "user question"},
        ]
        result = llm._raw_gen(llm, model="local", messages=msgs)
        assert result == "the answer"

    def test_prompt_contains_instruction_and_context(self, llm):
        msgs = [
            {"content": "my context"},
            {"content": "my question"},
        ]
        llm._raw_gen(llm, model="local", messages=msgs)
        prompt = llm.llama.last_call["prompt"]
        assert "### Instruction" in prompt
        assert "### Context" in prompt
        assert "my question" in prompt
        assert "my context" in prompt
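
    # _raw_gen is expected to hard-code max_tokens=150 and echo=False when it
    # calls the model; the assertions below pin that behaviour.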
    def test_max_tokens_passed(self, llm):
        msgs = [{"content": "c"}, {"content": "q"}]
        llm._raw_gen(llm, model="local", messages=msgs)
        assert llm.llama.last_call["max_tokens"] == 150
        assert llm.llama.last_call["echo"] is False


# ---------------------------------------------------------------------------
# _raw_gen_stream
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestRawGenStream:

    def test_yields_text_chunks(self, llm):
        msgs = [{"content": "c"}, {"content": "q"}]
        chunks = list(llm._raw_gen_stream(llm, model="local", messages=msgs))
        assert chunks == ["chunk1", "chunk2"]

    def test_prompt_format(self, llm):
        msgs = [{"content": "ctx"}, {"content": "question"}]
        list(llm._raw_gen_stream(llm, model="local", messages=msgs))
        prompt = llm.llama.last_call["prompt"]
        assert "### Instruction" in prompt
        assert "### Answer" in prompt

    def test_stream_flag_passed(self, llm):
        msgs = [{"content": "c"}, {"content": "q"}]
        list(
            llm._raw_gen_stream(llm, model="local", messages=msgs, stream=True)
        )
        assert llm.llama.last_call["stream"] is True