mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-21 16:40:35 +00:00
Compare commits
2 Commits
new_api
...
translatio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa44a92a67 | ||
|
|
01d791470b |
@@ -1,12 +1,13 @@
|
|||||||
from .audio_processor import AudioProcessor
|
from .audio_processor import AudioProcessor
|
||||||
from .core import TranscriptionEngine
|
from .core import TranscriptionEngine
|
||||||
from .parse_args import parse_args
|
from .parse_args import parse_args
|
||||||
from .web.web_interface import get_web_interface_html
|
from .web.web_interface import get_web_interface_html, get_inline_ui_html
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TranscriptionEngine",
|
"TranscriptionEngine",
|
||||||
"AudioProcessor",
|
"AudioProcessor",
|
||||||
"parse_args",
|
"parse_args",
|
||||||
"get_web_interface_html",
|
"get_web_interface_html",
|
||||||
|
"get_inline_ui_html",
|
||||||
"download_simulstreaming_backend",
|
"download_simulstreaming_backend",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
|
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from starlette.staticfiles import StaticFiles
|
from starlette.staticfiles import StaticFiles
|
||||||
@@ -38,7 +38,7 @@ app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
|||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get():
|
async def get():
|
||||||
return HTMLResponse(get_web_interface_html())
|
return HTMLResponse(get_inline_ui_html())
|
||||||
|
|
||||||
|
|
||||||
async def handle_websocket_results(websocket, results_generator):
|
async def handle_websocket_results(websocket, results_generator):
|
||||||
|
|||||||
60
whisperlivekit/translate/gemma_translate.py
Normal file
60
whisperlivekit/translate/gemma_translate.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# gemma_translate.py
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
|
MODEL_ID = "google/gemma-3-270m-it"
|
||||||
|
|
||||||
|
def build_prompt(tokenizer, text, target_lang, source_lang=None):
|
||||||
|
# Use the model's chat template for best results
|
||||||
|
if source_lang:
|
||||||
|
user_msg = (
|
||||||
|
f"Translate the following {source_lang} text into {target_lang}.\n"
|
||||||
|
f"Return only the translation.\n\n"
|
||||||
|
f"Text:\n{text}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
user_msg = (
|
||||||
|
f"Translate the following text into {target_lang}.\n"
|
||||||
|
f"Return only the translation.\n\n"
|
||||||
|
f"Text:\n{text}"
|
||||||
|
)
|
||||||
|
chat = [{"role": "user", "content": user_msg}]
|
||||||
|
return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
||||||
|
|
||||||
|
def translate(text, target_lang, source_lang=None, max_new_tokens=256, temperature=0.2, top_p=0.95):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
MODEL_ID,
|
||||||
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
|
device_map="auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = build_prompt(tokenizer, text, target_lang, source_lang)
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output_ids = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
do_sample=temperature > 0.0,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Slice off the prompt to keep only the assistant answer
|
||||||
|
generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
|
||||||
|
out = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
||||||
|
return out
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ap = argparse.ArgumentParser(description="Translate with google/gemma-3-270m-it")
|
||||||
|
ap.add_argument("--text", required=True, help="Text to translate")
|
||||||
|
ap.add_argument("--to", dest="target_lang", required=True, help="Target language (e.g., French, Spanish)")
|
||||||
|
ap.add_argument("--from", dest="source_lang", default=None, help="Source language (optional)")
|
||||||
|
ap.add_argument("--temp", type=float, default=0.2, help="Sampling temperature (0 = deterministic-ish)")
|
||||||
|
ap.add_argument("--max-new", type=int, default=256, help="Max new tokens")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
print(translate(args.text, args.target_lang, args.source_lang, max_new_tokens=args.max_new, temperature=args.temp))
|
||||||
121
whisperlivekit/translate/nllb_translate.py
Normal file
121
whisperlivekit/translate/nllb_translate.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
# nllb_translate.py
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
|
||||||
|
MODEL_ID = "facebook/nllb-200-distilled-600M"
|
||||||
|
|
||||||
|
# Common language shortcuts → NLLB codes (extend as needed)
|
||||||
|
LANG_MAP = {
|
||||||
|
"english": "eng_Latn",
|
||||||
|
"en": "eng_Latn",
|
||||||
|
"french": "fra_Latn",
|
||||||
|
"fr": "fra_Latn",
|
||||||
|
"spanish": "spa_Latn",
|
||||||
|
"es": "spa_Latn",
|
||||||
|
"german": "deu_Latn",
|
||||||
|
"de": "deu_Latn",
|
||||||
|
"italian": "ita_Latn",
|
||||||
|
"it": "ita_Latn",
|
||||||
|
"portuguese": "por_Latn",
|
||||||
|
"pt": "por_Latn",
|
||||||
|
"arabic": "arb_Arab",
|
||||||
|
"ar": "arb_Arab",
|
||||||
|
"russian": "rus_Cyrl",
|
||||||
|
"ru": "rus_Cyrl",
|
||||||
|
"turkish": "tur_Latn",
|
||||||
|
"tr": "tur_Latn",
|
||||||
|
"chinese": "zho_Hans",
|
||||||
|
"zh": "zho_Hans", # Simplified
|
||||||
|
"zh-cn": "zho_Hans",
|
||||||
|
"zh-hans": "zho_Hans",
|
||||||
|
"zh-hant": "zho_Hant", # Traditional
|
||||||
|
"japanese": "jpn_Jpan",
|
||||||
|
"ja": "jpn_Jpan",
|
||||||
|
"korean": "kor_Hang",
|
||||||
|
"ko": "kor_Hang",
|
||||||
|
"dutch": "nld_Latn",
|
||||||
|
"nl": "nld_Latn",
|
||||||
|
"polish": "pol_Latn",
|
||||||
|
"pl": "pol_Latn",
|
||||||
|
"swedish": "swe_Latn",
|
||||||
|
"sv": "swe_Latn",
|
||||||
|
"norwegian": "nob_Latn",
|
||||||
|
"no": "nob_Latn",
|
||||||
|
"danish": "dan_Latn",
|
||||||
|
"da": "dan_Latn",
|
||||||
|
"finnish": "fin_Latn",
|
||||||
|
"fi": "fin_Latn",
|
||||||
|
"catalan": "cat_Latn",
|
||||||
|
"ca": "cat_Latn",
|
||||||
|
"hindi": "hin_Deva",
|
||||||
|
"hi": "hin_Deva",
|
||||||
|
"vietnamese": "vie_Latn",
|
||||||
|
"vi": "vie_Latn",
|
||||||
|
"indonesian": "ind_Latn",
|
||||||
|
"id": "ind_Latn",
|
||||||
|
"thai": "tha_Thai",
|
||||||
|
"th": "tha_Thai",
|
||||||
|
}
|
||||||
|
|
||||||
|
def norm_lang(code: str) -> str:
|
||||||
|
c = code.strip().lower()
|
||||||
|
return LANG_MAP.get(c, code)
|
||||||
|
|
||||||
|
def translate_texts(texts: List[str], src_code: str, tgt_code: str,
|
||||||
|
max_new_tokens=512, device=None, dtype=None) -> List[str]:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, src_lang=src_code)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
|
MODEL_ID,
|
||||||
|
torch_dtype=dtype if dtype is not None else (torch.float16 if torch.cuda.is_available() else torch.float32),
|
||||||
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
|
)
|
||||||
|
if device:
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
|
||||||
|
if device or torch.cuda.is_available():
|
||||||
|
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
forced_bos = tokenizer.convert_tokens_to_ids(tgt_code)
|
||||||
|
with torch.no_grad():
|
||||||
|
gen = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
forced_bos_token_id=forced_bos,
|
||||||
|
)
|
||||||
|
outs = tokenizer.batch_decode(gen, skip_special_tokens=True)
|
||||||
|
return [o.strip() for o in outs]
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser(description="Translate with facebook/nllb-200-distilled-600M")
|
||||||
|
ap.add_argument("--text", help="Inline text to translate")
|
||||||
|
ap.add_argument("--file", help="Path to a UTF-8 text file (one example per line)")
|
||||||
|
ap.add_argument("--src", required=True, help="Source language (e.g. fr, fra_Latn)")
|
||||||
|
ap.add_argument("--tgt", required=True, help="Target language (e.g. en, eng_Latn)")
|
||||||
|
ap.add_argument("--max-new", type=int, default=512, help="Max new tokens")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
src = norm_lang(args.src)
|
||||||
|
tgt = norm_lang(args.tgt)
|
||||||
|
|
||||||
|
batch: List[str] = []
|
||||||
|
if args.text:
|
||||||
|
batch.append(args.text)
|
||||||
|
if args.file:
|
||||||
|
lines = Path(args.file).read_text(encoding="utf-8").splitlines()
|
||||||
|
batch.extend([ln for ln in lines if ln.strip()])
|
||||||
|
|
||||||
|
if not batch:
|
||||||
|
raise SystemExit("Provide --text or --file")
|
||||||
|
|
||||||
|
results = translate_texts(batch, src, tgt, max_new_tokens=args.max_new)
|
||||||
|
for i, (inp, out) in enumerate(zip(batch, results), 1):
|
||||||
|
print(f"\n--- Sample {i} ---")
|
||||||
|
print(f"SRC [{src}]: {inp}")
|
||||||
|
print(f"TGT [{tgt}]: {out}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
38
whisperlivekit/translate/sentence_segmenter.py
Normal file
38
whisperlivekit/translate/sentence_segmenter.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import regex
|
||||||
|
from functools import lru_cache
|
||||||
|
class SentenceSegmenter:
|
||||||
|
|
||||||
|
"""
|
||||||
|
Regex sentence splitter for Latin languages, Japanese and Chinese.
|
||||||
|
It is based on sacrebleu TokenizerV14International(BaseTokenizer).
|
||||||
|
|
||||||
|
Returns: a list of strings, where each string is a sentence.
|
||||||
|
Spaces following punctuation are appended after punctuation within the sequence.
|
||||||
|
Total number of characters in the output is the same as in the input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sep = 'ŽžŽžSentenceSeparatorŽžŽž' # string that certainly won't be in src or target
|
||||||
|
latin_terminals = '!?.'
|
||||||
|
jap_zh_terminals = '。!?'
|
||||||
|
terminals = latin_terminals + jap_zh_terminals
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# end of sentence characters:
|
||||||
|
terminals = self.terminals
|
||||||
|
self._re = [
|
||||||
|
# Separate out punctuations preceeded by a non-digit.
|
||||||
|
# If followed by space-like sequence of characters, they are
|
||||||
|
# appended to the punctuation, not to the next sequence.
|
||||||
|
(regex.compile(r'(\P{N})(['+terminals+r'])(\p{Z}*)'), r'\1\2\3'+self.sep),
|
||||||
|
# Separate out punctuations followed by a non-digit
|
||||||
|
(regex.compile(r'('+terminals+r')(\P{N})'), r'\1'+self.sep+r'\2'),
|
||||||
|
# # Separate out symbols
|
||||||
|
# -> no, we don't tokenize but segment the punctuation
|
||||||
|
# (regex.compile(r'(\p{S})'), r' \1 '),
|
||||||
|
]
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2**16)
|
||||||
|
def __call__(self, line):
|
||||||
|
for (_re, repl) in self._re:
|
||||||
|
line = _re.sub(repl, line)
|
||||||
|
return [ t for t in line.split(self.sep) if t != '' ]
|
||||||
466
whisperlivekit/translate/simul_llm_translate.py
Normal file
466
whisperlivekit/translate/simul_llm_translate.py
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import ctranslate2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import transformers
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def generate_words(sp, step_results):
|
||||||
|
tokens_buffer = []
|
||||||
|
|
||||||
|
for step_result in step_results:
|
||||||
|
is_new_word = step_result.token.startswith("▁")
|
||||||
|
|
||||||
|
if is_new_word and tokens_buffer:
|
||||||
|
word = sp.decode(tokens_buffer)
|
||||||
|
if word:
|
||||||
|
yield word
|
||||||
|
tokens_buffer = []
|
||||||
|
|
||||||
|
tokens_buffer.append(step_result.token_id)
|
||||||
|
|
||||||
|
if tokens_buffer:
|
||||||
|
word = sp.decode(tokens_buffer)
|
||||||
|
if word:
|
||||||
|
yield word
|
||||||
|
|
||||||
|
from sentence_segmenter import SentenceSegmenter
|
||||||
|
|
||||||
|
class LLMTranslator:
|
||||||
|
|
||||||
|
def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None):
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
print("Loading the model...", file=sys.stderr)
|
||||||
|
self.generator = ctranslate2.Generator("ct2_EuroLLM-9B-Instruct/", device="cuda")
|
||||||
|
self.sp = spm.SentencePieceProcessor("EuroLLM-9B-Instruct/tokenizer.model")
|
||||||
|
self.tokenizer = transformers.AutoTokenizer.from_pretrained("EuroLLM-9B-Instruct/")
|
||||||
|
print("...done", file=sys.stderr)
|
||||||
|
|
||||||
|
self.max_context_length = max_context_length
|
||||||
|
|
||||||
|
self.max_tokens_to_trim = self.max_context_length - 10
|
||||||
|
self.len_ratio = len_ratio
|
||||||
|
|
||||||
|
# my regex sentence segmenter
|
||||||
|
self.segmenter = SentenceSegmenter()
|
||||||
|
|
||||||
|
# self.max_generation_length = 512
|
||||||
|
# self.max_prompt_length = context_length - max_generation_length
|
||||||
|
|
||||||
|
def start_dialog(self):
|
||||||
|
return [{'role':'system', 'content': self.system_prompt }]
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(self, dialog):
|
||||||
|
toks = self.tokenizer.apply_chat_template(dialog, tokenize=True, add_generation_prompt=False)
|
||||||
|
if len(dialog) == 3:
|
||||||
|
toks = toks[:-2]
|
||||||
|
print("len toks:", len(toks), file=sys.stderr)
|
||||||
|
# print(toks, file=sys.stderr)
|
||||||
|
|
||||||
|
c = self.tokenizer.convert_ids_to_tokens(toks)
|
||||||
|
# print(c,file=sys.stderr)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def translate(self, src, tgt_forced=""):
|
||||||
|
#src, tgt_forced = self.trim(src, tgt_forced)
|
||||||
|
|
||||||
|
dialog = self.start_dialog()
|
||||||
|
dialog += [{'role':'user','content': src}]
|
||||||
|
if tgt_forced != "":
|
||||||
|
dialog += [{'role':'assistant','content': tgt_forced}]
|
||||||
|
|
||||||
|
prompt_tokens = self.build_prompt(dialog)
|
||||||
|
if self.len_ratio is not None:
|
||||||
|
limit_len = int(len(self.tokenizer.encode(src)) * self.len_ratio) + 10
|
||||||
|
limit_kw = {'max_length': limit_len}
|
||||||
|
else:
|
||||||
|
limit_kw = {}
|
||||||
|
step_results = self.generator.generate_tokens(
|
||||||
|
prompt_tokens,
|
||||||
|
**limit_kw,
|
||||||
|
# end_token=tokenizer.eos_token,
|
||||||
|
# sampling_temperature=0.6,
|
||||||
|
# sampling_topk=20,
|
||||||
|
# sampling_topp=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
#output_ids = []
|
||||||
|
for step_result in step_results:
|
||||||
|
# is_new_word = step_result.token.startswith("▁")
|
||||||
|
# if is_new_word and output_ids:
|
||||||
|
# word = self.sp.decode(output_ids)
|
||||||
|
# print(word, end=" ", flush=True, file=sys.stderr)
|
||||||
|
# output_ids = []
|
||||||
|
# output_ids.append(step_result.token_id)
|
||||||
|
res.append(step_result)
|
||||||
|
|
||||||
|
#if output_ids:
|
||||||
|
# word = self.sp.decode(output_ids)
|
||||||
|
# print(word, file=sys.stderr)
|
||||||
|
|
||||||
|
return self.sp.decode([r.token_id for r in res])
|
||||||
|
# print(res)
|
||||||
|
# print([s.token for s in res], file=sys.stderr)
|
||||||
|
# print([s.token==self.tokenizer.eos_token for s in res], file=sys.stderr)
|
||||||
|
|
||||||
|
class ParallelTextBuffer:
|
||||||
|
def __init__(self, tokenizer, max_tokens, trimming="segments", init_src="", init_tgt=""):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
self.src_buffer = [] # list of lists
|
||||||
|
if init_src:
|
||||||
|
self.src_buffer.append(init_src)
|
||||||
|
|
||||||
|
self.tgt_buffer = [] # list of strings
|
||||||
|
if init_tgt:
|
||||||
|
self.tgt_buffer.append(init_tgt)
|
||||||
|
|
||||||
|
self.trimming = trimming
|
||||||
|
if self.trimming == "sentences":
|
||||||
|
self.segmenter = SentenceSegmenter()
|
||||||
|
|
||||||
|
def len_src(self):
|
||||||
|
return sum(len(t) for t in self.src_buffer) + len(self.src_buffer) - 1
|
||||||
|
|
||||||
|
def insert(self, src, tgt):
|
||||||
|
self.src_buffer.append(src)
|
||||||
|
self.tgt_buffer.append(tgt)
|
||||||
|
|
||||||
|
def insert_src_suffix(self, s):
|
||||||
|
if self.src_buffer:
|
||||||
|
self.src_buffer[-1][-1] += s
|
||||||
|
else:
|
||||||
|
self.src_buffer.append([s])
|
||||||
|
|
||||||
|
def trim_sentences(self):
|
||||||
|
# src_tok_lens = [len(self.tokenizer.encode(" ".join(b))) for b in self.src_buffer]
|
||||||
|
# tgt_tok_lens = [len(self.tokenizer.encode(t)) for t in self.tgt_buffer]
|
||||||
|
|
||||||
|
src = " ".join(" ".join(b) for b in self.src_buffer)
|
||||||
|
tgt = "".join(self.tgt_buffer)
|
||||||
|
|
||||||
|
src_sp_toks = self.tokenizer.encode(src)
|
||||||
|
tgt_sp_toks = self.tokenizer.encode(tgt)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def trim_sentence(text):
|
||||||
|
sents = self.segmenter(text)
|
||||||
|
print("SENTS:", len(sents), sents, file=sys.stderr)
|
||||||
|
return "".join(sents[1:])
|
||||||
|
|
||||||
|
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
|
||||||
|
nsrc = trim_sentence(src)
|
||||||
|
ntgt = trim_sentence(tgt)
|
||||||
|
if not nsrc or not ntgt:
|
||||||
|
print("src or tgt is empty after trimming.", file=sys.stderr)
|
||||||
|
print("src: ", src, file=sys.stderr)
|
||||||
|
print("tgt: ", tgt, file=sys.stderr)
|
||||||
|
break
|
||||||
|
src = nsrc
|
||||||
|
tgt = ntgt
|
||||||
|
src_sp_toks = self.tokenizer.encode(src)
|
||||||
|
tgt_sp_toks = self.tokenizer.encode(tgt)
|
||||||
|
print("TRIMMED SRC:", (src,), file=sys.stderr)
|
||||||
|
print("TRIMMED TGT:", (tgt,), file=sys.stderr)
|
||||||
|
|
||||||
|
self.src_buffer = [src.split()]
|
||||||
|
self.tgt_buffer = [tgt]
|
||||||
|
return src, tgt
|
||||||
|
|
||||||
|
def trim_segments(self):
|
||||||
|
print("BUFFER:", file=sys.stderr)
|
||||||
|
for s,t in zip(self.src_buffer, self.tgt_buffer):
|
||||||
|
print("\t", s,"...",t,file=sys.stderr) #,self.src_buffer, self.tgt_buffer, file=sys.stderr)
|
||||||
|
src = " ".join(" ".join(b) for b in self.src_buffer)
|
||||||
|
tgt = "".join(self.tgt_buffer)
|
||||||
|
|
||||||
|
src_sp_toks = self.tokenizer.encode(src)
|
||||||
|
tgt_sp_toks = self.tokenizer.encode(tgt)
|
||||||
|
|
||||||
|
while len(src_sp_toks) + len(tgt_sp_toks) > self.max_tokens:
|
||||||
|
if len(self.src_buffer) > 1 and len(self.tgt_buffer) > 1:
|
||||||
|
self.src_buffer.pop(0)
|
||||||
|
self.tgt_buffer.pop(0)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
src = " ".join(" ".join(b) for b in self.src_buffer)
|
||||||
|
tgt = "".join(self.tgt_buffer)
|
||||||
|
|
||||||
|
src_sp_toks = self.tokenizer.encode(src)
|
||||||
|
tgt_sp_toks = self.tokenizer.encode(tgt)
|
||||||
|
|
||||||
|
print("TRIMMED SEGMENTS SRC:", (src,), file=sys.stderr)
|
||||||
|
print("TRIMMED SEGMENTS TGT:", (tgt,), file=sys.stderr)
|
||||||
|
|
||||||
|
return src, tgt
|
||||||
|
|
||||||
|
def trim(self):
|
||||||
|
if self.trimming == "sentences":
|
||||||
|
return self.trim_sentences()
|
||||||
|
return self.trim_segments()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SimulLLM:
|
||||||
|
|
||||||
|
def __init__(self, llmtrans, min_len=0, chunk=1, trimming="sentences", language="ja", init_src="", init_tgt=""):
|
||||||
|
self.llmtranslator = llmtrans
|
||||||
|
|
||||||
|
#self.src_buffer = init_src
|
||||||
|
#self.confirmed_tgt = init_tgt
|
||||||
|
|
||||||
|
self.buffer = ParallelTextBuffer(self.llmtranslator.tokenizer, self.llmtranslator.max_tokens_to_trim, trimming=trimming, init_src=init_src, init_tgt=init_tgt)
|
||||||
|
|
||||||
|
self.last_inserted = []
|
||||||
|
self.last_unconfirmed = ""
|
||||||
|
|
||||||
|
self.min_len = min_len
|
||||||
|
|
||||||
|
self.step = chunk
|
||||||
|
self.language = language
|
||||||
|
if language in ["ja", "zh"]:
|
||||||
|
self.specific_space = ""
|
||||||
|
else:
|
||||||
|
self.specific_space = " "
|
||||||
|
|
||||||
|
def insert(self, src):
|
||||||
|
if isinstance(src, str):
|
||||||
|
self.last_inserted.append(src)
|
||||||
|
else:
|
||||||
|
self.last_inserted += src
|
||||||
|
|
||||||
|
def insert_suffix(self, text):
|
||||||
|
'''
|
||||||
|
Insert suffix of a word to the last inserted word.
|
||||||
|
It may be because the word was split to multiple parts in the input, each with different timestamps.
|
||||||
|
'''
|
||||||
|
if self.last_inserted:
|
||||||
|
self.last_inserted[-1] += text
|
||||||
|
elif self.src_buffer:
|
||||||
|
self.buffer.insert_src_suffix(text)
|
||||||
|
else:
|
||||||
|
# this shouldn't happen
|
||||||
|
self.last_inserted.append(text)
|
||||||
|
|
||||||
|
def trim_longest_common_prefix(self, a,b):
|
||||||
|
if self.language not in ["ja", "zh"]:
|
||||||
|
a = a.split()
|
||||||
|
b = b.split()
|
||||||
|
i = 0
|
||||||
|
for i,(x,y) in enumerate(zip(a,b)):
|
||||||
|
if x != y:
|
||||||
|
break
|
||||||
|
if self.language in ["ja", "zh"]:
|
||||||
|
#print("tady160",(a, b, i), file=sys.stderr)
|
||||||
|
return a[:i], b[i:]
|
||||||
|
else:
|
||||||
|
return " ".join(a[:i]), " ".join(b[i:])
|
||||||
|
|
||||||
|
def process_iter(self):
|
||||||
|
if self.buffer.len_src() + len(self.last_inserted) < self.min_len:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
src, forced_tgt = self.buffer.trim() #llmtranslator.trim(" ".join(self.src_buffer), self.confirmed_tgt)
|
||||||
|
#self.src_buffer = self.src_buffer.split()
|
||||||
|
#src = " ".join(self.src_buffer)
|
||||||
|
|
||||||
|
confirmed_out = ""
|
||||||
|
run = False
|
||||||
|
for i in range(self.step, len(self.last_inserted), self.step):
|
||||||
|
for w in self.last_inserted[i-self.step:i]:
|
||||||
|
src += " " + w
|
||||||
|
run = True
|
||||||
|
if not run: break
|
||||||
|
|
||||||
|
print("SRC",src,file=sys.stderr)
|
||||||
|
|
||||||
|
print("FORCED TGT",forced_tgt,file=sys.stderr)
|
||||||
|
out = self.llmtranslator.translate(src, forced_tgt)
|
||||||
|
print("OUT",out,file=sys.stderr)
|
||||||
|
confirmed, unconfirmed = self.trim_longest_common_prefix(self.last_unconfirmed, out)
|
||||||
|
self.last_unconfirmed = unconfirmed
|
||||||
|
#print("tady", (self.confirmed_tgt, self.specific_space, confirmed), file=sys.stderr)
|
||||||
|
if confirmed:
|
||||||
|
# self.confirmed_tgt += self.specific_space + confirmed
|
||||||
|
# print(confirmed_out, confirmed, file=sys.stderr)
|
||||||
|
confirmed_out += self.specific_space + confirmed
|
||||||
|
print("CONFIRMED NOW:",confirmed,file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
print(file=sys.stderr)
|
||||||
|
print(file=sys.stderr)
|
||||||
|
print("#################",file=sys.stderr)
|
||||||
|
if run:
|
||||||
|
self.buffer.insert(self.last_inserted, confirmed_out)
|
||||||
|
self.last_inserted = []
|
||||||
|
|
||||||
|
ret = confirmed_out
|
||||||
|
print("RET:",ret,file=sys.stderr)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
return self.last_unconfirmed
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--input-instance', type=str, default=None, help="Filename of instances to simulate input. If not set, txt input is read from stdin.")
|
||||||
|
#parser.add_argument('--output_instance', type=str, default=None, help="Write output as instance into this file, while also writing to stdout.")
|
||||||
|
parser.add_argument('--min-chunk-size', type=int, default=1,
|
||||||
|
help='Minimum number of space-delimited words to process in each LocalAgreement update. The more, the higher quality, but slower.')
|
||||||
|
parser.add_argument('--min-len', type=int, default=1,
|
||||||
|
help='Minimum number of space-delimited words at the beginning.')
|
||||||
|
#parser.add_argument('--start_at', type=int, default=0, help='Skip first N words.')
|
||||||
|
|
||||||
|
# maybe later
|
||||||
|
#parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
|
||||||
|
#parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
|
||||||
|
|
||||||
|
lan_to_name = {
|
||||||
|
"de": "German",
|
||||||
|
"ja": "Japanese",
|
||||||
|
"zh-tr": "Chinese Traditional",
|
||||||
|
"zh-sim": "Chinese Simplified",
|
||||||
|
"cs": "Czech",
|
||||||
|
}
|
||||||
|
parser.add_argument('--lan', '--language', type=str, default="de",
|
||||||
|
help="Target language code.",
|
||||||
|
choices=["de", "ja","zh-tr","zh-sim","cs"])
|
||||||
|
|
||||||
|
SrcLang = "English" # always
|
||||||
|
TgtLang = "German"
|
||||||
|
default_prompt="You are simultaneous interpreter from {SrcLang} to {TgtLang}. We are at a conference. It is important that you translate " + \
|
||||||
|
"only what you hear, nothing else!"
|
||||||
|
parser.add_argument('--sys_prompt', type=str, default=None,
|
||||||
|
help='System prompt. If None, default one is used, depending on the language. The prompt should ')
|
||||||
|
|
||||||
|
default_init = "Please, go ahead, you can start with your presentation, we are ready."
|
||||||
|
|
||||||
|
|
||||||
|
default_inits_tgt = {
|
||||||
|
'de': "Bitte schön, Sie können mit Ihrer Präsentation beginnen, wir sind bereit.",
|
||||||
|
'ja': "どうぞ、プレゼンテーションを始めてください。", # # Please go ahead and start your presentation. # this is in English
|
||||||
|
'zh-tr': "請繼續,您可以開始您的簡報,我們已經準備好了。",
|
||||||
|
'zh-sim': "请吧,你可以开始发言了,我们已经准备好了。",
|
||||||
|
'cs': "Prosím, můžete začít s prezentací, jsme připraveni.",
|
||||||
|
}
|
||||||
|
parser.add_argument('--init_prompt_src', type=str, default=None, help='Init translation with source text. It should be a complete sentence in the source language. '
|
||||||
|
'It can be context specific for the given input. Default is ')
|
||||||
|
parser.add_argument('--init_prompt_tgt', type=str, default=None, help='Init translation with this target. It should be example translation of init_prompt_src. '
|
||||||
|
' There is default init message, depending on the language.')
|
||||||
|
|
||||||
|
parser.add_argument('--len-threshold', type=float, default=None, help='Ratio of the length of the source and generated target, in number of sentencepiece tokens. '
|
||||||
|
'It should reflect the target language and. If not set, no len-threshold is used.')
|
||||||
|
|
||||||
|
# how many times is target text longer than English
|
||||||
|
lan_thresholds = {
|
||||||
|
'de': 1.3, # 12751/9817 ... the proportion of subword tokens for ACL6060 dev de vs. en text, for EuroLLM-9B-Instruct tokenizer
|
||||||
|
'ja': 1.34, # 13187/9817
|
||||||
|
'zh': 1.23, # 12115/9817
|
||||||
|
'zh-tr': 1.23, # 12115/9817
|
||||||
|
'zh-sim': 1.23, # 12115/9817
|
||||||
|
# 'cs': I don't know # guessed
|
||||||
|
}
|
||||||
|
parser.add_argument('--language-specific-len-threshold', default=False, action="store_true",
|
||||||
|
help='Use language-specific length threshold, e.g. 1.3 for German.')
|
||||||
|
|
||||||
|
parser.add_argument("--max-context-length", type=int, default=4096, help="Maximum number of tokens in the model to use.")
|
||||||
|
|
||||||
|
parser.add_argument("--buffer_trimming", type=str, default="sentences", choices=["segments","sentences"], help="Buffer trimming strategy.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.sys_prompt is None:
|
||||||
|
TgtLang = lan_to_name[args.lan]
|
||||||
|
sys_prompt = default_prompt.format(SrcLang=SrcLang, TgtLang=TgtLang)
|
||||||
|
else:
|
||||||
|
sys_prompt = args.sys_prompt
|
||||||
|
|
||||||
|
if args.init_prompt_src is None:
|
||||||
|
init_src = default_init.split()
|
||||||
|
if args.init_prompt_tgt is None:
|
||||||
|
init_tgt = default_inits_tgt[args.lan]
|
||||||
|
if args.lan == "ja":
|
||||||
|
init_src = 'Please go ahead and start your presentation.'.split()
|
||||||
|
print("WARNING: Default init_prompt_src not set and language is Japanese. The init_src prompt changed to be more verbose.", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
print("WARNING: init_prompt_tgt is used, init_prompt_src is None, the default one. It may be wrong!", file=sys.stderr)
|
||||||
|
init_tgt = args.init_prompt_tgt
|
||||||
|
else:
|
||||||
|
init_src = args.init_prompt_src.split()
|
||||||
|
if args.init_prompt_tgt is None:
|
||||||
|
print("WARNING: init_prompt_src is used, init_prompt_tgt is None, so the default one is used. It may be wrong!", file=sys.stderr)
|
||||||
|
init_tgt = default_inits_tgt[args.lan]
|
||||||
|
else:
|
||||||
|
init_tgt = args.init_prompt_tgt
|
||||||
|
|
||||||
|
print("INFO: System prompt:", sys_prompt, file=sys.stderr)
|
||||||
|
print("INFO: Init prompt src:", init_src, file=sys.stderr)
|
||||||
|
print("INFO: Init prompt tgt:", init_tgt, file=sys.stderr)
|
||||||
|
|
||||||
|
if args.language_specific_len_threshold:
|
||||||
|
if args.len_threshold is not None:
|
||||||
|
print("ERROR: --len-threshold is set, but --language-specific-len-threshold is also set. Only one can be used.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
len_threshold = lan_thresholds[args.lan]
|
||||||
|
else:
|
||||||
|
len_threshold = args.len_threshold
|
||||||
|
|
||||||
|
llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold)
|
||||||
|
lan = args.lan if not args.lan.startswith("zh") else "zh"
|
||||||
|
simul = SimulLLM(llmtrans,language=lan, min_len=args.min_len, chunk=args.min_chunk_size,
|
||||||
|
init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming
|
||||||
|
)
|
||||||
|
|
||||||
|
# two input options
|
||||||
|
if args.input_instance is not None:
|
||||||
|
print("INFO: Reading input from file", args.input_instance, file=sys.stderr)
|
||||||
|
import json
|
||||||
|
with open(args.input_instance, "r") as f:
|
||||||
|
instance = json.load(f)
|
||||||
|
|
||||||
|
asr_source = instance["prediction"]
|
||||||
|
timestamps = instance["delays"]
|
||||||
|
elapsed = instance["elapsed"]
|
||||||
|
|
||||||
|
yield_ts_words = zip(timestamps, timestamps, elapsed, asr_source.split())
|
||||||
|
else:
|
||||||
|
print("INFO: Reading stdin in txt format", file=sys.stderr)
|
||||||
|
def yield_input():
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
ts, beg, end, *_ = line.split()
|
||||||
|
text = line[len(ts)+len(beg)+len(end)+3:]
|
||||||
|
ts = float(ts)
|
||||||
|
# in rare cases, the first word is a suffix of the previous word, that was split to multiple parts
|
||||||
|
if text[0] != " ":
|
||||||
|
first, *words = text.split()
|
||||||
|
yield (ts, beg, end, " "+first) # marking the first word with " ", so that it can be later detected and inserted as suffix
|
||||||
|
else:
|
||||||
|
words = text.split()
|
||||||
|
for w in words:
|
||||||
|
yield (ts, beg, end, w)
|
||||||
|
yield_ts_words = yield_input()
|
||||||
|
|
||||||
|
#i = 0
|
||||||
|
for t,b,e,w in yield_ts_words:
|
||||||
|
if w.startswith(" "): # it is suffix of the previous word
|
||||||
|
w = w[1:]
|
||||||
|
simul.insert_suffix(w)
|
||||||
|
continue
|
||||||
|
simul.insert(w)
|
||||||
|
out = simul.process_iter()
|
||||||
|
if out:
|
||||||
|
print(t,b,e,out,flush=True)
|
||||||
|
# if i > 50:
|
||||||
|
# break
|
||||||
|
# i += 1
|
||||||
|
out = simul.finalize()
|
||||||
|
print(t,b,e,out,flush=True)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import importlib.resources as resources
|
import importlib.resources as resources
|
||||||
|
import base64
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -12,6 +13,67 @@ def get_web_interface_html():
|
|||||||
logger.error(f"Error loading web interface HTML: {e}")
|
logger.error(f"Error loading web interface HTML: {e}")
|
||||||
return "<html><body><h1>Error loading interface</h1></body></html>"
|
return "<html><body><h1>Error loading interface</h1></body></html>"
|
||||||
|
|
||||||
|
def get_inline_ui_html():
|
||||||
|
"""Returns the complete web interface HTML with all assets embedded in a single call."""
|
||||||
|
try:
|
||||||
|
# Load HTML template
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
|
||||||
|
html_content = f.read()
|
||||||
|
|
||||||
|
# Load CSS and embed it
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
|
||||||
|
css_content = f.read()
|
||||||
|
|
||||||
|
# Load JS and embed it
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
|
||||||
|
js_content = f.read()
|
||||||
|
|
||||||
|
# Load SVG files and convert to data URIs
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
|
||||||
|
system_svg = f.read()
|
||||||
|
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
|
||||||
|
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
|
||||||
|
light_svg = f.read()
|
||||||
|
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
|
||||||
|
|
||||||
|
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
|
||||||
|
dark_svg = f.read()
|
||||||
|
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
|
||||||
|
|
||||||
|
# Replace external references with embedded content
|
||||||
|
html_content = html_content.replace(
|
||||||
|
'<link rel="stylesheet" href="/web/live_transcription.css" />',
|
||||||
|
f'<style>\n{css_content}\n</style>'
|
||||||
|
)
|
||||||
|
|
||||||
|
html_content = html_content.replace(
|
||||||
|
'<script src="/web/live_transcription.js"></script>',
|
||||||
|
f'<script>\n{js_content}\n</script>'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace SVG references with data URIs
|
||||||
|
html_content = html_content.replace(
|
||||||
|
'<img src="/web/src/system_mode.svg" alt="" />',
|
||||||
|
f'<img src="{system_data_uri}" alt="" />'
|
||||||
|
)
|
||||||
|
|
||||||
|
html_content = html_content.replace(
|
||||||
|
'<img src="/web/src/light_mode.svg" alt="" />',
|
||||||
|
f'<img src="{light_data_uri}" alt="" />'
|
||||||
|
)
|
||||||
|
|
||||||
|
html_content = html_content.replace(
|
||||||
|
'<img src="/web/src/dark_mode.svg" alt="" />',
|
||||||
|
f'<img src="{dark_data_uri}" alt="" />'
|
||||||
|
)
|
||||||
|
|
||||||
|
return html_content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating embedded web interface: {e}")
|
||||||
|
return "<html><body><h1>Error loading embedded interface</h1></body></html>"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
@@ -28,6 +90,6 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get():
|
async def get():
|
||||||
return HTMLResponse(get_web_interface_html())
|
return HTMLResponse(get_inline_ui_html())
|
||||||
|
|
||||||
uvicorn.run(app=app)
|
uvicorn.run(app=app)
|
||||||
|
|||||||
Reference in New Issue
Block a user