diff --git a/src/speechbox/diarize.py b/src/speechbox/diarize.py index 7bbf03e..f5253ea 100644 --- a/src/speechbox/diarize.py +++ b/src/speechbox/diarize.py @@ -8,6 +8,8 @@ from transformers import pipeline from transformers.pipelines.audio_utils import ffmpeg_read +from .utils.diarize_utils import match_segments + class ASRDiarizationPipeline: def __init__( @@ -34,7 +36,7 @@ def from_pretrained( "automatic-speech-recognition", model=asr_model, chunk_length_s=chunk_length_s, - token=use_auth_token, # 08/25/2023: Changed argument from use_auth_token to token + token=use_auth_token, # 08/25/2023: Changed argument from use_auth_token to token **kwargs, ) diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token) @@ -43,18 +45,16 @@ def from_pretrained( def __call__( self, inputs: Union[np.ndarray, List[np.ndarray]], - group_by_speaker: bool = True, + iou_threshold: float = 0.0, **kwargs, - ): + ) -> list[dict]: """ Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input audio is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The audio is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding timestamps. The speaker diarizer timestamps are aligned with the ASR transcription timestamps to give - speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the - transcriptions, as these timestamps may straddle across transcribed utterances from the ASR output. Thus, we - find the diarizer timestamps that are closest to the ASR timestamps and partition here. - + speaker-labelled transcriptions. We perform a best intersection over union (IoU) to select the best match between + the speaker diarizer segment and the ASR transcription segment. Args: inputs (`np.ndarray` or `bytes` or `str` or `dict`): The inputs is either : @@ -69,30 +69,35 @@ def __call__( np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in decoding (but used at inference to provide more context to the model). Only use `stride` with CTC models. - group_by_speaker (`bool`): - Whether to group consecutive utterances by one speaker into a single segment. If False, will return - transcriptions on a chunk-by-chunk basis. + iou_threshold (float): + The threshold under which an IoU is considere too low. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update additional asr or diarization configuration parameters - To update the asr configuration, use the prefix *asr_* for each configuration parameter. - To update the diarization configuration, use the prefix *diarization_* for each configuration parameter. - Added this support related to issue #25: 08/25/2023 - + Return: A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a dictionary with the following keys: - **text** (`str` ) -- The recognized text. - **speaker** (`str`) -- The associated speaker. - **timestamps** (`tuple`) -- The start and end time for the chunk / segment. + + Note: + If no match occur between the speaker diarizer segment and the ASR transcription segment, a `NO_SPEAKER` label + will be assign as we can't infer properly the speaker of the segment. """ kwargs_asr = { argument[len("asr_") :]: value for argument, value in kwargs.items() if argument.startswith("asr_") } kwargs_diarization = { - argument[len("diarization_") :]: value for argument, value in kwargs.items() if argument.startswith("diarization_") + argument[len("diarization_") :]: value + for argument, value in kwargs.items() + if argument.startswith("diarization_") } - + inputs, diarizer_inputs = self.preprocess(inputs) diarization = self.diarization_pipeline( @@ -100,72 +105,29 @@ def __call__( **kwargs_diarization, ) - segments = [] - for segment, track, label in diarization.itertracks(yield_label=True): - segments.append({'segment': {'start': segment.start, 'end': segment.end}, - 'track': track, - 'label': label}) - - # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...}) - # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...}) - new_segments = [] - prev_segment = cur_segment = segments[0] - - for i in range(1, len(segments)): - cur_segment = segments[i] - - # check if we have changed speaker ("label") - if cur_segment["label"] != prev_segment["label"] and i < len(segments): - # add the start/end times for the super-segment to the new list - new_segments.append( - { - "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}, - "speaker": prev_segment["label"], - } - ) - prev_segment = segments[i] - - # add the last segment(s) if there was no speaker change - new_segments.append( - { - "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]}, - "speaker": prev_segment["label"], - } - ) + dia_seg, dia_label = [], [] + for segment, _, label in diarization.itertracks(yield_label=True): + dia_seg.append([segment.start, segment.end]) + dia_label.append(label) + + assert ( + dia_seg + ), "The result from the diarization pipeline: `diarization_segments` is empty. No segments found from the diarization process." asr_out = self.asr_pipeline( {"array": inputs, "sampling_rate": self.sampling_rate}, return_timestamps=True, **kwargs_asr, ) - transcript = asr_out["chunks"] - - # get the end timestamps for each chunk from the ASR output - end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript]) - segmented_preds = [] - - # align the diarizer timestamps and the ASR timestamps - for segment in new_segments: - # get the diarizer end timestamp - end_time = segment["segment"]["end"] - # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here - upto_idx = np.argmin(np.abs(end_timestamps - end_time)) - - if group_by_speaker: - segmented_preds.append( - { - "speaker": segment["speaker"], - "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]), - "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]), - } - ) - else: - for i in range(upto_idx + 1): - segmented_preds.append({"speaker": segment["speaker"], **transcript[i]}) + segmented_preds = asr_out["chunks"] + + dia_seg = np.array(dia_seg) + asr_seg = np.array([[*chunk["timestamp"]] for chunk in segmented_preds]) + + asr_labels = match_segments(dia_seg, dia_label, asr_seg, threshold=iou_threshold, no_match_label="NO_SPEAKER") - # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin) - transcript = transcript[upto_idx + 1 :] - end_timestamps = end_timestamps[upto_idx + 1 :] + for i, label in enumerate(asr_labels): + segmented_preds[i]["speaker"] = label return segmented_preds diff --git a/src/speechbox/utils/diarize_utils.py b/src/speechbox/utils/diarize_utils.py new file mode 100644 index 0000000..62a7932 --- /dev/null +++ b/src/speechbox/utils/diarize_utils.py @@ -0,0 +1,95 @@ +import numpy as np + + +def IoU(diarized_segments: np.ndarray, asr_segments: np.ndarray) -> np.ndarray: + """ + Calculates the Intersection over Union (IoU) between diarized_segments and asr_segments. + + Args: + ----------- + - diarized_segments (np.ndarray): An array representing N segments with shape (M, 2), where each row + contains the start and end times of a diarized segment. + - asr_segments (np.ndarray): An array representing M segments with shape (N, 2), where each row contains + the start and end times of an asr segment. + + Returns: + -------- + - np.ndarray: A 2D array of shape (N, M) representing the IoU between each pair of diarized and. + The value at position (i, j) in the array corresponds to the IoU between the asr segment i and the diarized segment j. + Values are in the range [0, 1], where 0 indicates no intersection and 1 indicates perfect overlap. + + Note: + - The IoU is calculated as the ratio of the intersection over the union of the time intervals. + - Segments with no overlap result in an IoU value of 0. + - Segments with overlap but no intersection (e.g., one segment completely contained within another) can + have an IoU greater than 0. + + Example: + ```python + diarized_segments = np.array([[0, 5], [3, 8], [6, 10]]) + asr_segments = np.array([[2, 6], [1, 4]]) + + IoU_values = IoU(diarized_segments, asr_segments) + print(IoU_values) + # Output + # [[0.5 0.5 0.] + # [0.6 0.14285714 0.]] + ``` + """ + # We measure intersection between each of the N asr_segments [Nx2] and each M of diarize_ segments [Mx2] + # The result is a NxM matrix. intersection <= 0 mean no intersection. + starts = np.maximum(asr_segments[:, 0, np.newaxis], diarized_segments[:, 0]) + ends = np.minimum(asr_segments[:, 1, np.newaxis], diarized_segments[:, 1]) + intersections = np.maximum(ends - starts, 0) + + # Union for segments without overlap will lead to invalid results but it does not matters + # as we opt them out eventually. + union = np.maximum(asr_segments[:, 1, np.newaxis], diarized_segments[:, 1]) - np.minimum( + asr_segments[:, 0, np.newaxis], diarized_segments[:, 0] + ) + + # Negative results are zeroed as they are invalid. + intersection_over_union = np.maximum(intersections / union, 0) + + return intersection_over_union + + +def match_segments( + diarized_segments: np.ndarray, + diarized_labels: list[str], + asr_segments: np.ndarray, + threshold: float = 0.0, + no_match_label: str = "NO_SPEAKER", +) -> np.ndarray: + """ + Perform segment matching between diarized segments and ASR (Automatic Speech Recognition) segments. + + Args: + ----- + - diarized_segments (np.ndarray): Array representing diarized speaker segments. + - diarized_labels (list[str]): List of labels corresponding to diarized_segments. + - asr_segments (np.ndarray): Array representing ASR speaker segments. + - threshold (float, optional): IoU (Intersection over Union) threshold for matching. Default is 0.0. + - no_match_label (str, optional): Label assigned when no matching segment is found. Default is "NO_SPEAKER". + + Returns: + -------- + - np.ndarray: Array of labels corresponding to the best-matched ASR segments for each diarized segment. + + Notes: + - The function calculates IoU between diarized segments and ASR segments and considers only segments with IoU above the threshold. + - If no matching segment is found, the specified `no_match_label` is assigned. + - The returned array represents the labels of the best-matched ASR segments for each diarized segment. + """ + iou_results = IoU(diarized_segments, asr_segments) + # Zero out iou below threshold. + iou_results[iou_results <= threshold] = 0.0 + # We create a no match label which value will be threshold + diarized_labels = [no_match_label] + diarized_labels + # If there is nothing above threshold, no_match_label will be assigned. + iou_results = np.hstack([threshold * np.ones((iou_results.shape[0], 1)), iou_results]) + # Will find argument with highest iou (if all zeroes, will assign first (no_match_label)). + best_match_idx = np.argmax(iou_results, axis=1) + assigned_labels = np.take(diarized_labels, best_match_idx) + + return assigned_labels