diff --git a/backend/test.sh b/backend/test.sh index 1a648af8a2..4271c2560e 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -8,6 +8,7 @@ export ENCRYPTION_SECRET="omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c pytest tests/unit/test_transcript_segment.py -v pytest tests/unit/test_text_similarity.py -v +pytest tests/unit/test_text_containment.py -v pytest tests/unit/test_speaker_sample.py -v pytest tests/unit/test_speaker_sample_migration.py -v pytest tests/unit/test_users_add_sample_transaction.py -v diff --git a/backend/tests/unit/test_speaker_sample.py b/backend/tests/unit/test_speaker_sample.py index 6ab9612107..38e12f3b2c 100644 --- a/backend/tests/unit/test_speaker_sample.py +++ b/backend/tests/unit/test_speaker_sample.py @@ -31,9 +31,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript is None assert is_valid is False @@ -52,9 +50,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript is None assert is_valid is False @@ -72,9 +68,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript is None assert is_valid is False @@ -119,9 +113,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript is None assert is_valid is False @@ -137,11 +129,11 @@ def test_verify_and_transcribe_sample_text_mismatch(monkeypatch): def fake_deepgram(*_args, **_kwargs): return words - def fake_similarity(_text1, _text2): + def fake_containment(_text1, _text2): return 0.5 monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - monkeypatch.setattr(speaker_sample, "compute_text_similarity", fake_similarity) + monkeypatch.setattr(speaker_sample, "compute_text_containment", fake_containment) transcript, is_valid, reason = asyncio.run( speaker_sample.verify_and_transcribe_sample( @@ -151,7 +143,7 @@ def fake_similarity(_text1, _text2): assert transcript == "good morning thanks for coming" assert is_valid is False - assert reason == "text_mismatch: similarity=0.50" + assert reason == "text_mismatch: containment=0.50" def test_verify_and_transcribe_sample_text_mismatch_just_below(monkeypatch): @@ -163,21 +155,19 @@ def test_verify_and_transcribe_sample_text_mismatch_just_below(monkeypatch): def fake_deepgram(*_args, **_kwargs): return words - def fake_similarity(_text1, _text2): - return 0.59 + def fake_containment(_text1, _text2): + return 0.89 monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - monkeypatch.setattr(speaker_sample, "compute_text_similarity", fake_similarity) + monkeypatch.setattr(speaker_sample, "compute_text_containment", fake_containment) transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample( - b"audio", 16000, expected_text="galaxy salsa party" - ) + speaker_sample.verify_and_transcribe_sample(b"audio", 16000, expected_text="galaxy salsa party") ) assert transcript == "galaxy salsa makes the party loud" assert is_valid is False - assert reason == "text_mismatch: similarity=0.59" + assert reason == "text_mismatch: containment=0.89" def test_verify_and_transcribe_sample_success(monkeypatch): @@ -186,19 +176,39 @@ def test_verify_and_transcribe_sample_success(monkeypatch): def fake_deepgram(*_args, **_kwargs): return words - def fake_similarity(_text1, _text2): - return 0.9 + def fake_containment(_text1, _text2): + return 0.95 + + monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) + monkeypatch.setattr(speaker_sample, "compute_text_containment", fake_containment) + + transcript, is_valid, reason = asyncio.run( + speaker_sample.verify_and_transcribe_sample(b"audio", 16000, expected_text="thanks for joining the meeting") + ) + + assert transcript == "thanks for joining the meeting" + assert is_valid is True + assert reason == "ok" + + +def test_verify_and_transcribe_sample_containment_real_function(monkeypatch): + words = _make_words( + ["orbiting", "satellites", "drift", "above", "quietly"], + speakers=["SPEAKER_00"] * 5, + ) + + def fake_deepgram(*_args, **_kwargs): + return words monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - monkeypatch.setattr(speaker_sample, "compute_text_similarity", fake_similarity) transcript, is_valid, reason = asyncio.run( speaker_sample.verify_and_transcribe_sample( - b"audio", 16000, expected_text="thanks for joining the meeting" + b"audio", 16000, expected_text="today orbiting satellites drift above quietly" ) ) - assert transcript == "thanks for joining the meeting" + assert transcript == "orbiting satellites drift above quietly" assert is_valid is True assert reason == "ok" @@ -214,9 +224,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript == "party on planet pizza night" assert is_valid is True @@ -244,16 +252,14 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript == " ".join(texts) assert is_valid is True assert reason == "ok" -def test_verify_and_transcribe_sample_similarity_boundary(monkeypatch): +def test_verify_and_transcribe_sample_containment_boundary(monkeypatch): words = _make_words( ["space", "pirates", "sail", "the", "neon", "seas"], speakers=["SPEAKER_00"] * 6, @@ -262,16 +268,14 @@ def test_verify_and_transcribe_sample_similarity_boundary(monkeypatch): def fake_deepgram(*_args, **_kwargs): return words - def fake_similarity(_text1, _text2): - return 0.6 + def fake_containment(_text1, _text2): + return 0.9 monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - monkeypatch.setattr(speaker_sample, "compute_text_similarity", fake_similarity) + monkeypatch.setattr(speaker_sample, "compute_text_containment", fake_containment) transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample( - b"audio", 16000, expected_text="space pirates sail neon seas" - ) + speaker_sample.verify_and_transcribe_sample(b"audio", 16000, expected_text="space pirates sail neon seas") ) assert transcript == "space pirates sail the neon seas" @@ -287,9 +291,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript == "just a solo astronaut report" assert is_valid is True @@ -307,9 +309,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript == "blank speaker tag shows up" assert is_valid is True @@ -326,10 +326,10 @@ def fake_deepgram(*_args, **_kwargs): return words def fail_similarity(*_args, **_kwargs): - raise AssertionError("compute_text_similarity should not be called") + raise AssertionError("compute_text_containment should not be called") monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - monkeypatch.setattr(speaker_sample, "compute_text_similarity", fail_similarity) + monkeypatch.setattr(speaker_sample, "compute_text_containment", fail_similarity) transcript, is_valid, reason = asyncio.run( speaker_sample.verify_and_transcribe_sample(b"audio", 16000, expected_text="") @@ -350,10 +350,10 @@ def fake_deepgram(*_args, **_kwargs): return words def fail_similarity(*_args, **_kwargs): - raise AssertionError("compute_text_similarity should not be called") + raise AssertionError("compute_text_containment should not be called") monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - monkeypatch.setattr(speaker_sample, "compute_text_similarity", fail_similarity) + monkeypatch.setattr(speaker_sample, "compute_text_containment", fail_similarity) transcript, is_valid, reason = asyncio.run( speaker_sample.verify_and_transcribe_sample(b"audio", 16000, expected_text=None) @@ -370,9 +370,7 @@ def fake_deepgram(*_args, **_kwargs): monkeypatch.setattr(speaker_sample, "deepgram_prerecorded_from_bytes", fake_deepgram) - transcript, is_valid, reason = asyncio.run( - speaker_sample.verify_and_transcribe_sample(b"audio", 16000) - ) + transcript, is_valid, reason = asyncio.run(speaker_sample.verify_and_transcribe_sample(b"audio", 16000)) assert transcript is None assert is_valid is False diff --git a/backend/tests/unit/test_text_containment.py b/backend/tests/unit/test_text_containment.py new file mode 100644 index 0000000000..464f2b67ef --- /dev/null +++ b/backend/tests/unit/test_text_containment.py @@ -0,0 +1,54 @@ +""" +Unit tests for compute_text_containment function. +Tests character trigram containment across multiple languages. +""" + +from utils.text_utils import compute_text_containment + + +class TestComputeTextContainment: + """Tests for the compute_text_containment function.""" + + def test_transcript_fully_contained(self): + transcript = "hello world nice day" + expected = "greetings hello world nice day everyone" + assert compute_text_containment(transcript, expected) == 1.0 + + def test_transcript_not_contained(self): + transcript = "hello world nice day" + expected = "greetings hello world pleasant evening" + containment = compute_text_containment(transcript, expected) + assert containment < 0.9 + + def test_empty_transcript(self): + assert compute_text_containment("", "hello") == 0.0 + + def test_short_transcript_contained(self): + assert compute_text_containment("hi", "oh hi there") == 1.0 + + def test_short_transcript_not_contained(self): + assert compute_text_containment("hi", "hello there") == 0.0 + + def test_case_and_whitespace_normalization(self): + transcript = "Hello World" + expected = "greetings hello world everyone" + assert compute_text_containment(transcript, expected) == 1.0 + + def test_chinese_contained(self): + transcript = "你好世界" + expected = "今天你好世界朋友" + assert compute_text_containment(transcript, expected) == 1.0 + + def test_thai_contained(self): + transcript = "สวัสดีครับ" + expected = "วันนี้สวัสดีครับเพื่อนๆ" + assert compute_text_containment(transcript, expected) == 1.0 + + def test_expected_empty_returns_zero(self): + assert compute_text_containment("hello", "") == 0.0 + + def test_trigram_length_boundary_contained(self): + assert compute_text_containment("hey", "oh hey there") == 1.0 + + def test_trigram_length_boundary_not_contained(self): + assert compute_text_containment("hey", "oh he there") == 0.0 diff --git a/backend/utils/speaker_sample.py b/backend/utils/speaker_sample.py index 5a8408f4d7..2e516e0256 100644 --- a/backend/utils/speaker_sample.py +++ b/backend/utils/speaker_sample.py @@ -12,10 +12,10 @@ from utils.other.storage import delete_speech_profile_blob, download_speech_profile_bytes from utils.stt.pre_recorded import deepgram_prerecorded_from_bytes -from utils.text_utils import compute_text_similarity +from utils.text_utils import compute_text_containment MIN_WORDS = 5 -MIN_SIMILARITY = 0.6 +MIN_CONTAINMENT = 0.9 MIN_DOMINANT_SPEAKER_RATIO = 0.7 @@ -30,7 +30,7 @@ async def verify_and_transcribe_sample( Checks: 1. Transcription has at least MIN_WORDS words 2. Dominant speaker accounts for >= MIN_DOMINANT_SPEAKER_RATIO of words (via diarization) - 3. Transcribed text has >= MIN_SIMILARITY with expected text (if provided) + 3. Transcribed text has >= MIN_CONTAINMENT containment in expected text (if provided) Args: audio_bytes: WAV format audio bytes @@ -66,9 +66,9 @@ async def verify_and_transcribe_sample( transcript = ' '.join(w.get('text', '') for w in words) if expected_text: - similarity = compute_text_similarity(transcript, expected_text) - if similarity < MIN_SIMILARITY: - return transcript, False, f"text_mismatch: similarity={similarity:.2f}" + containment = compute_text_containment(transcript, expected_text) + if containment < MIN_CONTAINMENT: + return transcript, False, f"text_mismatch: containment={containment:.2f}" return transcript, True, "ok" diff --git a/backend/utils/text_utils.py b/backend/utils/text_utils.py index 6976bc5188..99c4549976 100644 --- a/backend/utils/text_utils.py +++ b/backend/utils/text_utils.py @@ -1,29 +1,57 @@ +def _normalize_text(text: str) -> str: + """Normalize text: lowercase and collapse whitespace.""" + return ' '.join(text.lower().split()) + + +def _get_trigrams(text: str) -> set: + """Get character trigrams from normalized text.""" + text = _normalize_text(text) + if len(text) < 3: + return {text} if text else set() + return {text[i : i + 3] for i in range(len(text) - 2)} + + def compute_text_similarity(text1: str, text2: str) -> float: """ Compute text similarity using character trigram Jaccard. - Language-agnostic: works for all languages including CJK (Chinese, Japanese, Korean). + Language-agnostic: works for all languages including CJK. + + Returns: + Similarity score 0.0 to 1.0 (1.0 = identical) + """ + trigrams1 = _get_trigrams(text1) + trigrams2 = _get_trigrams(text2) + + if not trigrams1 or not trigrams2: + return 0.0 + + return len(trigrams1 & trigrams2) / len(trigrams1 | trigrams2) + + +def compute_text_containment(transcript: str, expected: str) -> float: + """ + Compute containment of transcript trigrams within expected text. + Language-agnostic: works for all languages including CJK. Args: - text1: First text - text2: Second text + transcript: Transcript text to check for containment + expected: Expected text that should contain the transcript Returns: - Similarity score 0.0 to 1.0 (1.0 = identical) + Containment score 0.0 to 1.0 (1.0 = fully contained) """ + transcript_norm = _normalize_text(transcript) + expected_norm = _normalize_text(expected) - def get_trigrams(text: str) -> set: - # Normalize: lowercase and remove extra whitespace - text = ' '.join(text.lower().split()) - if len(text) < 3: - return {text} if text else set() - return {text[i : i + 3] for i in range(len(text) - 2)} + if not transcript_norm: + return 0.0 + if len(transcript_norm) < 3: + return 1.0 if transcript_norm in expected_norm else 0.0 - trigrams1 = get_trigrams(text1) - trigrams2 = get_trigrams(text2) + trigrams_transcript = _get_trigrams(transcript) + trigrams_expected = _get_trigrams(expected) - if not trigrams1 or not trigrams2: + if not trigrams_transcript: return 0.0 - intersection = trigrams1 & trigrams2 - union = trigrams1 | trigrams2 - return len(intersection) / len(union) + return len(trigrams_transcript & trigrams_expected) / len(trigrams_transcript)