add doctest

This commit is contained in:
Silas Kieser
2025-01-28 23:17:21 +01:00
parent b1f7034577
commit 15205f31d1

View File

@@ -1,33 +1,55 @@
from typing import List
import numpy as np
from collections import namedtuple
class TimeStampedSegment:
"""
Represents a segment of text with start and end timestamps.
class TimeStampedSegment():
def __init__(self, start=None, end=None, text=""):
Attributes:
start (float): The start time of the segment.
end (float): The end time of the segment.
text (str): The text of the segment.
"""
def __init__(self, start: float, end: float, text: str):
self.start = start
self.end = end
self.text = text
def __getitem__(self, key):
if key == 0:
return self.start
elif key == 1:
return self.end
elif key == 2:
return self.text
elif isinstance(key, slice):
raise NotImplementedError('Slicing not supported')
def __str__(self):
return f'{self.start} - {self.end}: {self.text}'
def __repr__(self):
return self.__str__()
def shift(self, shift):
def shift(self, shift: float):
"""
Shifts the segment by a given amount of time.
Args:
shift (float): The amount of time to shift the segment.
Returns:
TimeStampedSegment: A new segment shifted by the given amount of time.
Example:
>>> segment = TimeStampedSegment(0.0, 1.0, "Hello")
>>> segment.shift(1.0)
1.0 - 2.0: Hello
"""
return TimeStampedSegment(self.start + shift, self.end + shift, self.text)
def append_text(self, text):
def append_text(self, text: str):
"""
Appends text to the segment.
Args:
text (str): The text to append.
Example:
>>> segment = TimeStampedSegment(0.0, 1.0, "Hello")
>>> segment.append_text("!")
>>> segment
0.0 - 1.0: Hello!
"""
self.text += text
def __eq__(self, other):
@@ -40,56 +62,47 @@ class TimeStampedSegment():
return TimeStampedSegment(self.start, self.end, self.text + other)
else:
raise TypeError(f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'")
class TimeStampedText(list):
class TimeStampedText:
"""
Represents a collection of TimeStampedSegment instances.
def __init__(self, time_stamped_segments: list[TimeStampedSegment]):
super().__init__(time_stamped_segments)
self._index = 0
Attributes:
segments (List[TimeStampedSegment]): The list of segments.
"""
def __init__(self):
self.segments: List[TimeStampedSegment] = []
def words(self):
return [segment.text for segment in self]
def starts(self):
return [segment.start for segment in self]
def ends(self):
return [segment.end for segment in self]
def concatenate(self, sep:str, offset=0)->TimeStampedSegment:
def add_segment(self, segment: TimeStampedSegment):
"""
Concatenates the timestamped words or sentences into a single sequence with timing information.
This method joins all words in the sequence using the specified separator and preserves
the timing information from the first to the last word.
Adds a segment to the collection.
Args:
sep (str): Separator string used to join the words together
offset (float, optional): Time offset to add to begin/end timestamps. Defaults to 0.
Returns:
TimeStampedSegment: A new segment containing:
- Start time: First word's start time + offset
- End time: Last word's end time + offset
- Text: All words joined by separator
Examples:
>>> seg = TimeStampedSegment([(1.0, 2.0, "hello"), (2.1, 3.0, "world!")])
>>> result = seg.concatenate(" ")
>>> print(result)
(1.0, 3.0, "hello world!")
Notes:
Returns an empty TimeStampedSegment if the current segment contains no words.
segment (TimeStampedSegment): The segment to add.
Example:
>>> tst = TimeStampedText()
>>> tst.add_segment(TimeStampedSegment(0.0, 1.0, "Hello"))
>>> tst.add_segment(TimeStampedSegment(1.0, 2.0, "world"))
>>> len(tst)
2
"""
self.segments.append(segment)
if len(self) == 0:
return TimeStampedSegment()
def __repr__(self):
return f"TimeStampedText(segments={self.segments})"
combined_text = sep.join(self.words())
b = offset + self[0][0]
e = offset + self[-1][1]
return TimeStampedSegment(b, e, combined_text)
def __iter__(self):
return iter(self.segments)
def __getitem__(self, index):
return self.segments[index]
def __len__(self):
return len(self.segments)
# TODO: a function from_whisper_res()
if __name__ == "__main__":
import doctest
doctest.testmod(verbose=True)