Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/test_add_messages_with_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def test_add_messages_with_indexing_basic():
related_term_index_settings=settings.related_term_index_settings,
)
settings.storage_provider = storage
transcript = await Transcript.create(settings, name_tag="test")
transcript = await Transcript.create(settings, name="test")

metadata1 = TranscriptMessageMeta(speaker="Alice")
metadata2 = TranscriptMessageMeta(speaker="Bob")
Expand Down Expand Up @@ -76,7 +76,7 @@ async def test_add_messages_with_indexing_batched():
related_term_index_settings=settings.related_term_index_settings,
)
settings.storage_provider = storage
transcript = await Transcript.create(settings, name_tag="test")
transcript = await Transcript.create(settings, name="test")

# Add first batch
batch1 = [
Expand Down Expand Up @@ -133,7 +133,7 @@ async def test_transaction_rollback_on_error():
related_term_index_settings=settings.related_term_index_settings,
)
settings.storage_provider = storage
transcript = await Transcript.create(settings, name_tag="test")
transcript = await Transcript.create(settings, name="test")

# Add some valid messages first
batch1 = [
Expand Down
57 changes: 20 additions & 37 deletions test/test_incremental_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_incremental_index_building():
related_term_index_settings=settings.related_term_index_settings,
)
settings.storage_provider = storage1
transcript1 = await Transcript.create(settings, name_tag="test")
transcript1 = await Transcript.create(settings, name="test")

# Add some messages
messages1 = [
Expand All @@ -56,18 +56,16 @@ async def test_incremental_index_building():
tags=["file1"],
),
]
for msg in messages1:
await transcript1.messages.append(msg)

# Add messages with indexing
print("Adding messages with indexing...")
result1 = await transcript1.add_messages_with_indexing(messages1)

msg_count1 = await transcript1.messages.size()
print(f"Added {msg_count1} messages")

# Build index
print("Building index for first time...")
await transcript1.build_index()
print(f"Created {result1.semrefs_added} semantic refs")

ref_count1 = await transcript1.semantic_refs.size()
print(f"Created {ref_count1} semantic refs")

# Close first connection
await storage1.close()
Expand All @@ -84,7 +82,7 @@ async def test_incremental_index_building():
related_term_index_settings=settings2.related_term_index_settings,
)
settings2.storage_provider = storage2
transcript2 = await Transcript.create(settings2, name_tag="test")
transcript2 = await Transcript.create(settings2, name="test")

# Verify existing messages are there
msg_count_before = await transcript2.messages.size()
Expand All @@ -104,36 +102,23 @@ async def test_incremental_index_building():
tags=["file2"],
),
]
for msg in messages2:
await transcript2.messages.append(msg)

# Add messages with indexing
print("Adding more messages with indexing...")
result2 = await transcript2.add_messages_with_indexing(messages2)

msg_count2 = await transcript2.messages.size()
print(f"Now have {msg_count2} messages total")
assert msg_count2 == msg_count_before + len(messages2)

# Try to rebuild index - this should work incrementally
print("Rebuilding index...")
try:
await transcript2.build_index()
print("SUCCESS: Index rebuilt!")

ref_count2 = await transcript2.semantic_refs.size()
print(f"Now have {ref_count2} semantic refs (was {ref_count1})")

# We should have more refs now
assert (
ref_count2 >= ref_count1
), "Should have at least as many refs as before"

except Exception as e:
print(f"FAILED: {e}")
import traceback
print("SUCCESS: Messages added with incremental indexing!")
ref_count2 = await transcript2.semantic_refs.size()
print(f"Now have {ref_count2} semantic refs (was {ref_count1})")

traceback.print_exc()
pytest.fail(f"Index building failed: {e}")
# We should have more refs now
assert ref_count2 >= ref_count1, "Should have at least as many refs as before"

finally:
await storage2.close()
await storage2.close()


@pytest.mark.asyncio
Expand Down Expand Up @@ -164,8 +149,7 @@ async def test_incremental_index_with_vtt_files():
msg_count1 = await transcript1.messages.size()
print(f"Imported {msg_count1} messages from Confuse-A-Cat.vtt")

# Build index
await transcript1.build_index()
# Indexing already done by add_messages_with_indexing() in ingest
ref_count1 = await transcript1.semantic_refs.size()
print(f"Built index with {ref_count1} semantic refs")

Expand All @@ -190,9 +174,8 @@ async def test_incremental_index_with_vtt_files():
print(f"Now have {msg_count2} messages total")
assert msg_count2 > msg_count1, "Should have added more messages"

# Rebuild index incrementally
print("Rebuilding index incrementally...")
await transcript2.build_index()
# Indexing already done incrementally by add_messages_with_indexing()
print("Index built incrementally during ingestion")
ref_count2 = await transcript2.semantic_refs.size()
print(f"Now have {ref_count2} semantic refs (was {ref_count1})")

Expand Down
8 changes: 4 additions & 4 deletions test/test_podcast_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ async def test_podcast_add_messages_with_indexing():
related_term_index_settings=settings.related_term_index_settings,
)
settings.storage_provider = storage
podcast = await Podcast.create(settings, name_tag="test")
podcast = await Podcast.create(settings, name="test")

metadata1 = PodcastMessageMeta(speaker="Host", listeners=["Guest"])
metadata2 = PodcastMessageMeta(speaker="Guest", listeners=["Host"])
metadata1 = PodcastMessageMeta(speaker="Host", recipients=["Guest"])
metadata2 = PodcastMessageMeta(speaker="Guest", recipients=["Host"])

messages = [
PodcastMessage(text_chunks=["Welcome to the podcast!"], metadata=metadata1),
Expand Down Expand Up @@ -68,7 +68,7 @@ async def test_podcast_add_messages_batched():
related_term_index_settings=settings.related_term_index_settings,
)
settings.storage_provider = storage
podcast = await Podcast.create(settings, name_tag="test")
podcast = await Podcast.create(settings, name="test")

# Add first batch
metadata1 = PodcastMessageMeta(speaker="Host")
Expand Down
7 changes: 3 additions & 4 deletions test/test_podcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import os
import pytest
from datetime import timezone

from fixtures import needs_auth, temp_dir, embedding_model # type: ignore # Yes they are used!

Expand All @@ -25,7 +26,7 @@ async def test_ingest_podcast(
"testdata/FakePodcast.txt",
settings,
None,
Datetime.now(),
Datetime.now(timezone.utc), # Use timezone-aware datetime
3.0,
)

Expand All @@ -34,9 +35,7 @@ async def test_ingest_podcast(
assert len(pod.tags) > 0
assert await pod.messages.size() > 0

# Build the index
await pod.build_index()
# Verify the semantic refs were built by checking they exist
# Verify the semantic refs exist
assert pod.semantic_refs is not None

# Write the podcast to files
Expand Down
8 changes: 4 additions & 4 deletions test/test_storage_providers_unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ async def test_timestamp_index_behavior_parity(
time_index = await storage_provider.get_timestamp_index()

# Test empty lookup_range interface
start_time = Datetime.fromisoformat("2024-01-01T00:00:00")
end_time = Datetime.fromisoformat("2024-01-02T00:00:00")
start_time = Datetime.fromisoformat("2024-01-01T00:00:00Z")
end_time = Datetime.fromisoformat("2024-01-02T00:00:00Z")
date_range = DateRange(start=start_time, end=end_time)

empty_results = await time_index.lookup_range(date_range)
Expand Down Expand Up @@ -478,8 +478,8 @@ async def test_timestamp_index_range_queries(
timestamp_index = await storage_provider.get_timestamp_index()

# Test basic interface - empty range query
start_time = Datetime.fromisoformat("2024-01-01T00:00:00")
end_time = Datetime.fromisoformat("2024-01-02T00:00:00")
start_time = Datetime.fromisoformat("2024-01-01T00:00:00Z")
end_time = Datetime.fromisoformat("2024-01-02T00:00:00Z")
date_range = DateRange(start=start_time, end=end_time)

empty_results = await timestamp_index.lookup_range(date_range)
Expand Down
81 changes: 51 additions & 30 deletions test/test_transcripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import os
import tempfile
from datetime import timedelta
from typing import AsyncGenerator

from typeagent.transcripts.transcript_ingest import (
Expand All @@ -18,6 +19,10 @@
TranscriptMessage,
TranscriptMessageMeta,
)
from typeagent.knowpro.universal_message import (
UNIX_EPOCH,
format_timestamp_utc,
)
from typeagent.knowpro.convsettings import ConversationSettings
from typeagent.knowpro.interfaces import Datetime
from typeagent.aitools.embeddings import AsyncEmbeddingModel
Expand Down Expand Up @@ -130,12 +135,19 @@ async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings
if not text.strip():
continue

# Calculate timestamp from WebVTT start time
offset_seconds = webvtt_timestamp_to_seconds(caption.start)
timestamp = format_timestamp_utc(
UNIX_EPOCH + timedelta(seconds=offset_seconds)
)

metadata = TranscriptMessageMeta(
speaker=speaker,
start_time=caption.start,
end_time=caption.end,
recipients=[],
)
message = TranscriptMessage(
text_chunks=[text], metadata=metadata, timestamp=timestamp
)
message = TranscriptMessage(text_chunks=[text], metadata=metadata)
messages_list.append(message)

# Create in-memory collections
Expand All @@ -148,13 +160,13 @@ async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings
# Create transcript with in-memory storage
transcript = await Transcript.create(
settings,
name_tag="Test-Confuse-A-Cat",
messages=msg_coll,
semantic_refs=semref_coll,
semantic_ref_index=semref_index,
name="Test-Confuse-A-Cat",
tags=["Test-Confuse-A-Cat", "vtt-transcript"],
)

# Add messages to the transcript's collections
await transcript.messages.extend(messages_list)

# Verify the transcript was created correctly
assert isinstance(transcript, Transcript)
assert transcript.name_tag == "Test-Confuse-A-Cat"
Expand All @@ -177,35 +189,39 @@ async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings
assert len(first_message.text_chunks) > 0
assert first_message.text_chunks[0].strip() != ""

# Verify metadata has timestamp information
assert first_message.metadata.start_time is not None
assert first_message.metadata.end_time is not None
# Verify message has timestamp
assert first_message.timestamp is not None
assert first_message.timestamp.endswith("Z") # Should be UTC


def test_transcript_message_creation():
"""Test creating transcript messages manually."""
# Create a transcript message
metadata = TranscriptMessageMeta(
speaker="Test Speaker", start_time="00:00:10.000", end_time="00:00:15.000"
)
# Create a transcript message with timestamp
timestamp = format_timestamp_utc(UNIX_EPOCH + timedelta(seconds=10))
metadata = TranscriptMessageMeta(speaker="Test Speaker", recipients=[])

message = TranscriptMessage(
text_chunks=["This is a test message."], metadata=metadata, tags=["test"]
text_chunks=["This is a test message."],
metadata=metadata,
tags=["test"],
timestamp=timestamp,
)

# Test serialization
serialized = message.serialize()
assert serialized["textChunks"] == ["This is a test message."]
assert serialized["metadata"]["speaker"] == "Test Speaker"
assert serialized["metadata"]["start_time"] == "00:00:10.000"
assert serialized["metadata"]["recipients"] == []
assert serialized["tags"] == ["test"]
assert serialized["timestamp"] == timestamp

# Test deserialization
deserialized = TranscriptMessage.deserialize(serialized)
assert deserialized.text_chunks == ["This is a test message."]
assert deserialized.metadata.speaker == "Test Speaker"
assert deserialized.metadata.start_time == "00:00:10.000"
assert deserialized.metadata.recipients == []
assert deserialized.tags == ["test"]
assert deserialized.timestamp == timestamp


@pytest.mark.asyncio
Expand All @@ -218,7 +234,7 @@ async def test_transcript_creation():
settings = ConversationSettings(embedding_model)

transcript = await Transcript.create(
settings=settings, name_tag="Test Transcript", tags=["test", "empty"]
settings=settings, name="Test Transcript", tags=["test", "empty"]
)

assert transcript.name_tag == "Test Transcript"
Expand Down Expand Up @@ -270,12 +286,17 @@ async def test_transcript_knowledge_extraction_slow(
speaker = getattr(caption, "voice", None)
text = caption.text.strip()

# Calculate timestamp from WebVTT start time
offset_seconds = webvtt_timestamp_to_seconds(caption.start)
timestamp = format_timestamp_utc(UNIX_EPOCH + timedelta(seconds=offset_seconds))

metadata = TranscriptMessageMeta(
speaker=speaker,
start_time=caption.start,
end_time=caption.end,
recipients=[],
)
message = TranscriptMessage(
text_chunks=[text], metadata=metadata, timestamp=timestamp
)
message = TranscriptMessage(text_chunks=[text], metadata=metadata)
messages_list.append(message)

# Create in-memory collections
Expand All @@ -288,27 +309,27 @@ async def test_transcript_knowledge_extraction_slow(
# Create transcript with in-memory storage
transcript = await Transcript.create(
settings,
name_tag="Parrot-Test",
messages=msg_coll,
semantic_refs=semref_coll,
semantic_ref_index=semref_index,
name="Parrot-Test",
tags=["test", "parrot"],
)

# Verify we have messages
assert await transcript.messages.size() == len(messages_list)
assert len(messages_list) >= 3, "Need at least 3 messages for testing"

# Enable knowledge extraction
settings.semantic_ref_index_settings.auto_extract_knowledge = True
settings.semantic_ref_index_settings.batch_size = 10

# Build index (this should extract knowledge)
await transcript.build_index()
# Add messages with indexing (this should extract knowledge)
result = await transcript.add_messages_with_indexing(messages_list)

# Verify messages and semantic refs were created
assert await transcript.messages.size() == len(messages_list)
assert result.messages_added == len(messages_list)
assert result.semrefs_added > 0, "Should have extracted some semantic references"

# Verify semantic refs were created
semref_count = await transcript.semantic_refs.size()
assert semref_count > 0, "Should have extracted some semantic references"
assert semref_count > 0, "Should have semantic refs"

# Verify we have different types of knowledge
knowledge_types = set()
Expand Down
File renamed without changes.
Loading