106 changes: 34 additions & 72 deletions src/speechbox/diarize.py
@@ -8,6 +8,8 @@
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read

from .utils.diarize_utils import match_segments


class ASRDiarizationPipeline:
def __init__(
@@ -34,7 +36,7 @@ def from_pretrained(
"automatic-speech-recognition",
model=asr_model,
chunk_length_s=chunk_length_s,
token=use_auth_token, # 08/25/2023: Changed argument from use_auth_token to token
**kwargs,
)
diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
@@ -43,18 +45,16 @@ def from_pretrained(
def __call__(
self,
inputs: Union[np.ndarray, List[np.ndarray]],
group_by_speaker: bool = True,
iou_threshold: float = 0.0,
**kwargs,
):
) -> list[dict]:
"""
Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input audio
is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The audio
is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding
timestamps. The speaker diarizer timestamps are aligned with the ASR transcription timestamps to give
speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the
transcriptions, as these timestamps may straddle across transcribed utterances from the ASR output. Thus, we
find the diarizer timestamps that are closest to the ASR timestamps and partition here.

speaker-labelled transcriptions. We select, for each ASR transcription segment, the speaker diarizer segment
with the highest intersection over union (IoU).

Args:
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
The inputs is either:
@@ -69,103 +69,65 @@ def __call__(
np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
inference to provide more context to the model). Only use `stride` with CTC models.
group_by_speaker (`bool`):
Whether to group consecutive utterances by one speaker into a single segment. If False, will return
transcriptions on a chunk-by-chunk basis.
iou_threshold (`float`, *optional*, defaults to 0.0):
The IoU value at or below which a diarizer segment is not considered a match for an ASR segment. If no
diarizer segment scores above this threshold, the segment is labelled `NO_SPEAKER`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update additional asr or diarization configuration parameters
- To update the asr configuration, use the prefix *asr_* for each configuration parameter.
- To update the diarization configuration, use the prefix *diarization_* for each configuration parameter.
- This support was added in relation to issue #25 (08/25/2023).

Return:
A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a
dictionary with the following keys:
- **text** (`str` ) -- The recognized text.
- **speaker** (`str`) -- The associated speaker.
- **timestamps** (`tuple`) -- The start and end time for the chunk / segment.

Note:
If no match occurs between a speaker diarizer segment and an ASR transcription segment, the `NO_SPEAKER`
label is assigned, as the speaker of that segment cannot be reliably inferred.
"""
kwargs_asr = {
argument[len("asr_") :]: value for argument, value in kwargs.items() if argument.startswith("asr_")
}

kwargs_diarization = {
argument[len("diarization_") :]: value for argument, value in kwargs.items() if argument.startswith("diarization_")
argument[len("diarization_") :]: value
for argument, value in kwargs.items()
if argument.startswith("diarization_")
}

inputs, diarizer_inputs = self.preprocess(inputs)

diarization = self.diarization_pipeline(
{"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
**kwargs_diarization,
)

segments = []
for segment, track, label in diarization.itertracks(yield_label=True):
segments.append({'segment': {'start': segment.start, 'end': segment.end},
'track': track,
'label': label})

# diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
# we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
new_segments = []
prev_segment = cur_segment = segments[0]

for i in range(1, len(segments)):
cur_segment = segments[i]

# check if we have changed speaker ("label")
if cur_segment["label"] != prev_segment["label"] and i < len(segments):
# add the start/end times for the super-segment to the new list
new_segments.append(
{
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
"speaker": prev_segment["label"],
}
)
prev_segment = segments[i]

# add the last segment(s) if there was no speaker change
new_segments.append(
{
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
"speaker": prev_segment["label"],
}
)
dia_seg, dia_label = [], []
for segment, _, label in diarization.itertracks(yield_label=True):
dia_seg.append([segment.start, segment.end])
dia_label.append(label)

assert (
dia_seg
), "The diarization pipeline returned no segments: there is nothing to align with the ASR output."

asr_out = self.asr_pipeline(
{"array": inputs, "sampling_rate": self.sampling_rate},
return_timestamps=True,
**kwargs_asr,
)
transcript = asr_out["chunks"]

# get the end timestamps for each chunk from the ASR output
end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
segmented_preds = []

# align the diarizer timestamps and the ASR timestamps
for segment in new_segments:
# get the diarizer end timestamp
end_time = segment["segment"]["end"]
# find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
upto_idx = np.argmin(np.abs(end_timestamps - end_time))

if group_by_speaker:
segmented_preds.append(
{
"speaker": segment["speaker"],
"text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
"timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
}
)
else:
for i in range(upto_idx + 1):
segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
segmented_preds = asr_out["chunks"]

dia_seg = np.array(dia_seg)
asr_seg = np.array([[*chunk["timestamp"]] for chunk in segmented_preds])

asr_labels = match_segments(dia_seg, dia_label, asr_seg, threshold=iou_threshold, no_match_label="NO_SPEAKER")

# crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
transcript = transcript[upto_idx + 1 :]
end_timestamps = end_timestamps[upto_idx + 1 :]
for i, label in enumerate(asr_labels):
segmented_preds[i]["speaker"] = label

return segmented_preds
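To make the revised flow concrete, here is a minimal usage sketch. The checkpoint names, the audio path, and the `asr_generate_kwargs` value are illustrative assumptions rather than part of this diff; `iou_threshold` and the `asr_`/`diarization_` kwarg prefixes come from the code above, and the top-level `ASRDiarizationPipeline` import is assumed from the package root.

```python
from speechbox import ASRDiarizationPipeline

# Illustrative checkpoints; any compatible ASR model / pyannote diarizer pair should work.
pipeline = ASRDiarizationPipeline.from_pretrained(
    asr_model="openai/whisper-tiny",
    diarizer_model="pyannote/speaker-diarization",
    use_auth_token=True,
)

outputs = pipeline(
    "audio.wav",                                  # may also be a np.ndarray or dict, as documented above
    iou_threshold=0.1,                            # chunks with no IoU above 0.1 are labelled NO_SPEAKER
    asr_generate_kwargs={"max_new_tokens": 128},  # "asr_"-prefixed kwargs are forwarded to the ASR pipeline
)
for chunk in outputs:
    print(chunk["speaker"], chunk["timestamp"], chunk["text"])
```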

95 changes: 95 additions & 0 deletions src/speechbox/utils/diarize_utils.py
@@ -0,0 +1,95 @@
import numpy as np


def IoU(diarized_segments: np.ndarray, asr_segments: np.ndarray) -> np.ndarray:
"""
Calculates the Intersection over Union (IoU) between diarized_segments and asr_segments.

Args:
-----
- diarized_segments (np.ndarray): An array of shape (M, 2), where each row contains the start and end
times of a diarized segment.
- asr_segments (np.ndarray): An array of shape (N, 2), where each row contains the start and end times
of an ASR segment.

Returns:
--------
- np.ndarray: A 2D array of shape (N, M) containing the IoU between each pair of ASR and diarized segments.
The value at position (i, j) corresponds to the IoU between ASR segment i and diarized segment j.
Values are in the range [0, 1], where 0 indicates no overlap and 1 indicates perfect overlap.

Note:
- The IoU is calculated as the ratio of the intersection over the union of the time intervals.
- Segments with no overlap result in an IoU value of 0.
- Segments that overlap only partially (e.g., one segment completely contained within another) have an
IoU strictly between 0 and 1.

Example:
```python
diarized_segments = np.array([[0, 5], [3, 8], [6, 10]])
asr_segments = np.array([[2, 6], [1, 4]])

IoU_values = IoU(diarized_segments, asr_segments)
print(IoU_values)
# Output
# [[0.5 0.5 0.]
# [0.6 0.14285714 0.]]
```
"""
# We measure the intersection between each of the N asr_segments [Nx2] and each of the M
# diarized_segments [Mx2]. The result is an NxM matrix; an intersection <= 0 means no overlap.
starts = np.maximum(asr_segments[:, 0, np.newaxis], diarized_segments[:, 0])
ends = np.minimum(asr_segments[:, 1, np.newaxis], diarized_segments[:, 1])
intersections = np.maximum(ends - starts, 0)

# The union of non-overlapping segment pairs is not meaningful, but that does not matter
# as those pairs are filtered out eventually.
union = np.maximum(asr_segments[:, 1, np.newaxis], diarized_segments[:, 1]) - np.minimum(
asr_segments[:, 0, np.newaxis], diarized_segments[:, 0]
)

# Clamp at zero as a safeguard; non-overlapping pairs end up with an IoU of 0.
intersection_over_union = np.maximum(intersections / union, 0)

return intersection_over_union


def match_segments(
diarized_segments: np.ndarray,
diarized_labels: list[str],
asr_segments: np.ndarray,
threshold: float = 0.0,
no_match_label: str = "NO_SPEAKER",
) -> np.ndarray:
"""
Perform segment matching between diarized segments and ASR (Automatic Speech Recognition) segments.

Args:
-----
- diarized_segments (np.ndarray): Array representing diarized speaker segments.
- diarized_labels (list[str]): List of labels corresponding to diarized_segments.
- asr_segments (np.ndarray): Array representing ASR transcription segments.
- threshold (float, optional): IoU (Intersection over Union) threshold for matching. Default is 0.0.
- no_match_label (str, optional): Label assigned when no matching segment is found. Default is "NO_SPEAKER".

Returns:
--------
- np.ndarray: For each ASR segment, the label of its best-matching diarized segment (or `no_match_label`).

Notes:
- The function calculates the IoU between diarized segments and ASR segments and considers only matches with an IoU above the threshold.
- If no matching diarized segment is found for an ASR segment, the specified `no_match_label` is assigned.
- The returned array contains one label per ASR segment: that of its best-matching diarized segment.
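
Example (illustrative values, mirroring the `IoU` docstring above; both ASR chunks overlap the first diarized segment best):
```python
diarized_segments = np.array([[0, 5], [3, 8], [6, 10]])
diarized_labels = ["speaker_0", "speaker_1", "speaker_2"]
asr_segments = np.array([[2, 6], [1, 4]])

labels = match_segments(diarized_segments, diarized_labels, asr_segments, threshold=0.0)
print(labels)
# Output
# ['speaker_0' 'speaker_0']
```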
"""
iou_results = IoU(diarized_segments, asr_segments)
# Zero out IoU values at or below the threshold.
iou_results[iou_results <= threshold] = 0.0
# Prepend a no-match label whose score is the threshold itself.
diarized_labels = [no_match_label] + diarized_labels
# If nothing scores above the threshold, no_match_label is assigned.
iou_results = np.hstack([threshold * np.ones((iou_results.shape[0], 1)), iou_results])
# Take the column with the highest IoU; if all real scores were zeroed, the first column (no_match_label) wins.
best_match_idx = np.argmax(iou_results, axis=1)
assigned_labels = np.take(diarized_labels, best_match_idx)

return assigned_labels
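
A short self-contained check of the threshold behaviour; the segment values are made up for illustration, and the printed results are worked out by hand from the functions above (not taken from the PR's test suite):

```python
import numpy as np

from speechbox.utils.diarize_utils import match_segments

dia_seg = np.array([[0.0, 5.0], [5.0, 10.0]])
dia_label = ["speaker_0", "speaker_1"]
# The first chunk overlaps speaker_0 strongly (IoU = 4.0 / 5.0 = 0.8);
# the second barely touches speaker_1 (IoU = 0.2 / 5.0 = 0.04).
asr_seg = np.array([[0.5, 4.5], [9.8, 10.0]])

# With a permissive threshold, both chunks are matched to a speaker.
print(match_segments(dia_seg, dia_label, asr_seg, threshold=0.0))
# ['speaker_0' 'speaker_1']

# Raising the threshold rejects the weak second overlap, falling back to NO_SPEAKER.
print(match_segments(dia_seg, dia_label, asr_seg, threshold=0.1))
# ['speaker_0' 'NO_SPEAKER']
```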