mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
This commit is contained in:
@@ -16,12 +16,25 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
class DummyClient:
|
||||
def __init__(self, api_key):
|
||||
created["api_key"] = api_key
|
||||
self.generate_calls = []
|
||||
self.convert_calls = []
|
||||
|
||||
def generate(self, *, text, model, voice):
|
||||
self.generate_calls.append({"text": text, "model": model, "voice": voice})
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
class TextToSpeech:
|
||||
def __init__(self, outer):
|
||||
self._outer = outer
|
||||
|
||||
def convert(self, *, voice_id, model_id, text, output_format):
|
||||
self._outer.convert_calls.append(
|
||||
{
|
||||
"voice_id": voice_id,
|
||||
"model_id": model_id,
|
||||
"text": text,
|
||||
"output_format": output_format,
|
||||
}
|
||||
)
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
|
||||
self.text_to_speech = TextToSpeech(self)
|
||||
|
||||
client_module = ModuleType("elevenlabs.client")
|
||||
client_module.ElevenLabs = DummyClient
|
||||
@@ -35,8 +48,13 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
audio_base64, lang = tts.text_to_speech("Speak")
|
||||
|
||||
assert created["api_key"] == "api-key"
|
||||
assert tts.client.generate_calls == [
|
||||
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"}
|
||||
assert tts.client.convert_calls == [
|
||||
{
|
||||
"voice_id": "nPczCjzI2devNBz1zQrb",
|
||||
"model_id": "eleven_multilingual_v2",
|
||||
"text": "Speak",
|
||||
"output_format": "mp3_44100_128",
|
||||
}
|
||||
]
|
||||
assert lang == "en"
|
||||
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"
|
||||
|
||||
61
tests/tts/test_tts_creator.py
Normal file
61
tests/tts/test_tts_creator.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts_creator():
|
||||
return TTSCreator()
|
||||
|
||||
|
||||
def test_create_google_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("google_tts", "arg1", key="value")
|
||||
|
||||
mock_google_tts.assert_called_once_with("arg1", key="value")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_create_elevenlabs_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
|
||||
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
|
||||
instance = MagicMock()
|
||||
mock_elevenlabs_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
|
||||
|
||||
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_invalid_tts_type(tts_creator):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tts_creator.create_tts("unknown_tts")
|
||||
assert "No tts class found" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_tts_type_case_insensitivity(tts_creator):
|
||||
# Patch the provider registry to ensure case-insensitive lookup hits our mock
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("GoOgLe_TtS")
|
||||
|
||||
mock_google_tts.assert_called_once_with()
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_tts_providers_integrity(tts_creator):
|
||||
providers = tts_creator.tts_providers
|
||||
assert "google_tts" in providers
|
||||
assert "elevenlabs" in providers
|
||||
assert callable(providers["google_tts"])
|
||||
assert callable(providers["elevenlabs"])
|
||||
Reference in New Issue
Block a user