faster-whisper support

This commit is contained in:
Dominik Macháček
2023-04-19 15:31:35 +02:00
parent 6dc5cdf330
commit 8116b21b4e
2 changed files with 104 additions and 15 deletions

View File

@@ -3,19 +3,24 @@ Whisper realtime streaming for long speech-to-text transcription and translation
## Installation
This code works with two kinds of backends. Both require
```
pip install git+https://github.com/linto-ai/whisper-timestamped
XDG_CACHE_HOME=$(pwd)/pip-cache pip install git+https://github.com/linto-ai/whisper-timestamped
pip install librosa
pip install opus-fast-mosestokenizer
pip install torch
```
The most recommended backend is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.
An alternative, less restrictive, but slower backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped): `pip install git+https://github.com/linto-ai/whisper-timestamped`
The backend is loaded only when chosen. The unused one does not have to be installed.
## Usage
```
(p3) $ python3 whisper_online.py -h
usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model MODEL] [--model_dir MODEL_DIR] [--lan LAN] [--start_at START_AT] audio_path
usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model MODEL] [--model_dir MODEL_DIR] [--lan LAN] [--start_at START_AT] [--backend {faster-whisper,whisper_timestamped}] audio_path
positional arguments:
audio_path
@@ -30,6 +35,8 @@ options:
--lan LAN, --language LAN
Language code for transcription, e.g. en,de,cs.
--start_at START_AT Start processing audio at this time.
--backend {faster-whisper,whisper_timestamped}
Load only this backend for Whisper processing.
```
Example:

View File

@@ -1,15 +1,10 @@
#!/usr/bin/env python3
import sys
import numpy as np
import whisper
import whisper_timestamped
import librosa
import librosa
from functools import lru_cache
import torch
import time
from mosestokenizer import MosesTokenizer
import json
@lru_cache
def load_audio(fname):
@@ -22,10 +17,38 @@ def load_audio_chunk(fname, beg, end):
end_s = int(end*16000)
return audio[beg_s:end_s]
class WhisperASR:
def __init__(self, modelsize="small", lan="en", cache_dir="disk-cache-dir"):
# Whisper backend
class ASRBase:
    """Common interface shared by the Whisper backends.

    Subclasses must implement load_model() and transcribe(); __init__
    stores the language and delegates model creation to load_model().
    """

    def __init__(self, modelsize, lan, cache_dir):
        self.original_language = lan
        self.model = self.load_model(modelsize, cache_dir)

    def load_model(self, modelsize, cache_dir):
        # NotImplementedError (not the NotImplemented singleton, which is
        # not callable and would raise TypeError) marks the abstract method.
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")
## requires imports:
# import whisper
# import whisper_timestamped
class WhisperTimestampedASR(ASRBase):
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
On the other hand, the installation for GPU could be easier.
If used, requires imports:
import whisper
import whisper_timestamped
"""
def load_model(self, modelsize, cache_dir):
    """Load an openai-whisper model, storing downloaded weights under cache_dir."""
    loaded = whisper.load_model(modelsize, download_root=cache_dir)
    return loaded
def transcribe(self, audio, init_prompt=""):
result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
@@ -40,6 +63,52 @@ class WhisperASR:
o.append(t)
return o
def segments_end_ts(self, res):
    """Collect the end timestamp of every segment in a whisper_timestamped result."""
    ends = []
    for segment in res["segments"]:
        ends.append(segment["end"])
    return ends
class FasterWhisperASR(ASRBase):
    """Uses faster-whisper library as the backend. Works much faster, appx 4-times
    (in offline mode). For GPU, it requires installation with a specific CUDNN version.

    Requires imports, if used:
        import faster_whisper
    """

    def load_model(self, modelsize, cache_dir):
        """Create a CUDA/FP16 WhisperModel.

        cache_dir is ignored: setting it seemed not to work, so the default
        ~/.cache/huggingface/hub is used.
        """
        # this worked fast and reliably on NVIDIA L40
        model = WhisperModel(modelsize, device="cuda", compute_type="float16")

        # or run on GPU with INT8
        # tested: the transcripts were different, probably worse than with FP16,
        # and it was slightly (appx 20%) slower
        #model = WhisperModel(modelsize, device="cuda", compute_type="int8_float16")

        # or run on CPU with INT8
        # tested: works, but slow, appx 10-times slower than cuda FP16
        #model = WhisperModel(modelsize, device="cpu", compute_type="int8")  #, download_root="faster-disk-cache-dir/")
        return model

    def transcribe(self, audio, init_prompt=""):
        """Run faster-whisper on `audio`; return the materialized list of segments."""
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
        )
        # segments is a generator -- materialize it so callers can iterate repeatedly
        return list(segments)

    def ts_words(self, segments):
        """Flatten segments into (begin_ts, end_ts, word) tuples."""
        o = []
        for segment in segments:
            for word in segment.words:
                # stripping the spaces faster-whisper leaves around tokens
                w = word.word.strip()
                o.append((word.start, word.end, w))
        return o

    def segments_end_ts(self, res):
        """Return the end timestamp of each segment object."""
        return [s.end for s in res]
def to_flush(sents, offset=0):
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
@@ -253,7 +322,7 @@ class OnlineASRProcessor:
def chunk_completed_segment(self, res):
if self.commited == []: return
ends = [s["end"] for s in res["segments"]]
ends = self.asr.segments_end_ts(res)
t = self.commited[-1][1]
@@ -320,6 +389,7 @@ class OnlineASRProcessor:
## main:
import argparse
@@ -330,6 +400,7 @@ parser.add_argument('--model', type=str, default='large-v2', help="name of the W
parser.add_argument('--model_dir', type=str, default='disk-cache-dir', help="the path where Whisper models are saved (or downloaded to). Default: ./disk-cache-dir")
parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
args = parser.parse_args()
audio_path = args.audio_path
@@ -343,7 +414,18 @@ language = args.lan
t = time.time()
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
asr = WhisperASR(lan=language, modelsize=size)
#asr = WhisperASR(lan=language, modelsize=size)
# Select the Whisper backend; only the chosen backend's libraries are imported
# here, so the unused one does not have to be installed.
if args.backend == "faster-whisper":
from faster_whisper import WhisperModel
asr_cls = FasterWhisperASR
else:
import whisper
import whisper_timestamped
# from whisper_timestamped_model import WhisperTimestampedASR
asr_cls = WhisperTimestampedASR
# both backend classes share the ASRBase constructor signature
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_dir)
e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)