diff --git a/src/typeagent/mcp/server.py b/src/typeagent/mcp/server.py index 0c106ec..0ff612f 100644 --- a/src/typeagent/mcp/server.py +++ b/src/typeagent/mcp/server.py @@ -6,6 +6,7 @@ import argparse from dataclasses import dataclass +import os import time from typing import Any @@ -28,6 +29,9 @@ from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex from typeagent.storage.utils import create_storage_provider +# Example podcast index path for documentation and error messages +_EXAMPLE_PODCAST_INDEX = "tests/testdata/Episode_53_AdrianTchaikovsky_index" + class MCPTypeChatModel(typechat.TypeChatLanguageModel): """TypeChat language model that uses MCP sampling API instead of direct API calls.""" @@ -142,7 +146,9 @@ async def make_context( entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None ) - query_context = await load_podcast_index_or_database(settings, dbname) + query_context = await load_podcast_database_or_index( + settings, dbname, _podcast_index + ) # Use MCP-based model instead of one that requires API keys model = MCPTypeChatModel(session) @@ -161,24 +167,33 @@ async def make_context( return context -async def load_podcast_index_or_database( +async def load_podcast_database_or_index( settings: ConversationSettings, dbname: str | None = None, + podcast_index: str | None = None, ) -> query.QueryEvalContext[podcast.PodcastMessage, Any]: - if dbname is None: - conversation = await podcast.Podcast.read_from_file( - "tests/testdata/Episode_53_AdrianTchaikovsky_index", settings - ) - else: + if dbname is not None: + # Load from SQLite database conversation = await podcast.Podcast.create(settings) + elif podcast_index is not None: + # Load from JSON index files + conversation = await podcast.Podcast.read_from_file(podcast_index, settings) + else: + raise ValueError( + "Either --database or --podcast-index must be specified. " + "Use --podcast-index to specify the path to podcast index files " + f"(e.g., '{_EXAMPLE_PODCAST_INDEX}')." + ) return query.QueryEvalContext(conversation) # Create an MCP server mcp = FastMCP("typagent") -# Global variable to store database path (set via command-line argument) +# Global variables to store command-line arguments +# (no other straightforward way to pass to tool handlers) _dbname: str | None = None +_podcast_index: str | None = None @dataclass @@ -245,12 +260,49 @@ async def query_conversation( "--database", type=str, default=None, - help="Path to the SQLite database file (default: load from JSON file)", + help="Path to a SQLite database file with pre-indexed podcast data", + ) + parser.add_argument( + "-p", + "--podcast-index", + type=str, + default=None, + help="Path to podcast index files (excluding '_data.json' suffix), " + f"e.g., '{_EXAMPLE_PODCAST_INDEX}'", ) args = parser.parse_args() - # Store database path in global variable (no other straightforward way to pass to tool) + # Validate arguments + if args.database is None and args.podcast_index is None: + parser.error( + "Either --database or --podcast-index is required.\n" + "Example: python -m typeagent.mcp.server " + f"--podcast-index {_EXAMPLE_PODCAST_INDEX}" + ) + + if args.database is not None and args.podcast_index is not None: + parser.error("Cannot specify both --database and --podcast-index") + + # Validate file existence + if args.database is not None and not os.path.exists(args.database): + parser.error( + f"Database file not found: {args.database}\n" + "Please provide a valid path to an existing SQLite database." + ) + + if args.podcast_index is not None: + data_file = args.podcast_index + "_data.json" + if not os.path.exists(data_file): + parser.error( + f"Podcast index file not found: {data_file}\n" + "Please provide a valid path to podcast index files " + "(without the '_data.json' suffix).\n" + f"Example: {_EXAMPLE_PODCAST_INDEX}" + ) + + # Store in global variables for tool handlers _dbname = args.database + _podcast_index = args.podcast_index # Use stdio transport for simplicity mcp.run(transport="stdio") diff --git a/tests/test/conftest.py b/tests/conftest.py similarity index 89% rename from tests/test/conftest.py rename to tests/conftest.py index e5f326b..3533e23 100644 --- a/tests/test/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from collections.abc import AsyncGenerator, Iterator +from collections.abc import AsyncGenerator, Callable, Iterator import os +from pathlib import Path import tempfile from typing import Any @@ -43,6 +44,37 @@ MemorySemanticRefCollection, ) +# --- Testdata path utilities --- +# Locate the tests directory relative to this file +_TESTS_DIR = Path(__file__).resolve().parent # tests/ +_TESTDATA_DIR = _TESTS_DIR / "testdata" +_REPO_ROOT = _TESTS_DIR.parent + + +def get_testdata_path(filename: str) -> str: + """Return absolute path to a file in tests/testdata/.""" + return str(_TESTDATA_DIR / filename) + + +def get_repo_root() -> Path: + """Return the repository root path.""" + return _REPO_ROOT + + +def has_testdata_file(filename: str) -> bool: + """Check if a testdata file exists (for use in skipif conditions).""" + return (_TESTDATA_DIR / filename).exists() + + +# Commonly used test files as constants +CONFUSE_A_CAT_VTT = get_testdata_path("Confuse-A-Cat.vtt") +PARROT_SKETCH_VTT = get_testdata_path("Parrot_Sketch.vtt") +FAKE_PODCAST_TXT = get_testdata_path("FakePodcast.txt") +EPISODE_53_INDEX = get_testdata_path("Episode_53_AdrianTchaikovsky_index") +EPISODE_53_TRANSCRIPT = get_testdata_path("Episode_53_AdrianTchaikovsky.txt") +EPISODE_53_ANSWERS = get_testdata_path("Episode_53_Answer_results.json") +EPISODE_53_SEARCH = get_testdata_path("Episode_53_Search_results.json") + @pytest.fixture(scope="session") def needs_auth() -> None: @@ -63,6 +95,17 @@ def embedding_model() -> AsyncEmbeddingModel: return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) +@pytest.fixture(scope="session") +def testdata_path() -> Callable[[str], str]: + """Fixture returning a function to get absolute paths to testdata files. + + Usage: + def test_something(testdata_path): + path = testdata_path("Confuse-A-Cat.vtt") + """ + return get_testdata_path + + @pytest.fixture def temp_dir() -> Iterator[str]: with tempfile.TemporaryDirectory() as dir: diff --git a/tests/test/test_add_messages_with_indexing.py b/tests/test_add_messages_with_indexing.py similarity index 100% rename from tests/test/test_add_messages_with_indexing.py rename to tests/test_add_messages_with_indexing.py diff --git a/tests/test/test_auth.py b/tests/test_auth.py similarity index 100% rename from tests/test/test_auth.py rename to tests/test_auth.py diff --git a/tests/test/test_collections.py b/tests/test_collections.py similarity index 100% rename from tests/test/test_collections.py rename to tests/test_collections.py diff --git a/tests/test/test_conversation_metadata.py b/tests/test_conversation_metadata.py similarity index 100% rename from tests/test/test_conversation_metadata.py rename to tests/test_conversation_metadata.py diff --git a/tests/test/test_demo.py b/tests/test_demo.py similarity index 95% rename from tests/test/test_demo.py rename to tests/test_demo.py index 36acca3..599f006 100644 --- a/tests/test/test_demo.py +++ b/tests/test_demo.py @@ -3,7 +3,6 @@ import argparse import asyncio -import os import textwrap import time @@ -12,23 +11,21 @@ from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal from typeagent.podcasts import podcast -tests_dir = os.path.dirname(__file__) -root_dir = os.path.dirname(tests_dir) -DEFAULT_FILE = os.path.join(root_dir, "testdata", "Episode_53_AdrianTchaikovsky_index") +from conftest import EPISODE_53_INDEX parser = argparse.ArgumentParser() parser.add_argument( "filename", nargs="?", type=str, - default=DEFAULT_FILE, + default=EPISODE_53_INDEX, ) def test_main(really_needs_auth: None): # auth is needed because we use embeddings. # TODO: Only use the embeddings loaded from the file and cached. - asyncio.run(main(DEFAULT_FILE)) + asyncio.run(main(EPISODE_53_INDEX)) async def main(filename_prefix: str): diff --git a/tests/test/test_embedding_consistency.py b/tests/test_embedding_consistency.py similarity index 100% rename from tests/test/test_embedding_consistency.py rename to tests/test_embedding_consistency.py diff --git a/tests/test/test_embeddings.py b/tests/test_embeddings.py similarity index 100% rename from tests/test/test_embeddings.py rename to tests/test_embeddings.py diff --git a/tests/test/test_factory.py b/tests/test_factory.py similarity index 100% rename from tests/test/test_factory.py rename to tests/test_factory.py diff --git a/tests/test/test_incremental_index.py b/tests/test_incremental_index.py similarity index 96% rename from tests/test/test_incremental_index.py rename to tests/test_incremental_index.py index 077db4b..12f706a 100644 --- a/tests/test/test_incremental_index.py +++ b/tests/test_incremental_index.py @@ -18,8 +18,7 @@ ) from typeagent.transcripts.transcript_ingest import ingest_vtt_transcript -tests_dir = os.path.dirname(__file__) -root_dir = os.path.dirname(tests_dir) +from conftest import CONFUSE_A_CAT_VTT, PARROT_SKETCH_VTT @pytest.mark.asyncio @@ -144,9 +143,8 @@ async def test_incremental_index_with_vtt_files(): # First VTT file ingestion print("\n=== Import first VTT file ===") # Import the first transcript - DEFAULT_FILE = os.path.join(root_dir, "testdata", "Confuse-A-Cat.vtt") transcript1 = await ingest_vtt_transcript( - DEFAULT_FILE, + CONFUSE_A_CAT_VTT, settings, dbname=db_path, ) @@ -169,9 +167,8 @@ async def test_incremental_index_with_vtt_files(): settings2.semantic_ref_index_settings.auto_extract_knowledge = False # Ingest the second transcript - DEFAULT_FILE = os.path.join(root_dir, "testdata", "Parrot_Sketch.vtt") transcript2 = await ingest_vtt_transcript( - DEFAULT_FILE, + PARROT_SKETCH_VTT, settings2, dbname=db_path, ) diff --git a/tests/test/test_interfaces.py b/tests/test_interfaces.py similarity index 100% rename from tests/test/test_interfaces.py rename to tests/test_interfaces.py diff --git a/tests/test/test_knowledge.py b/tests/test_knowledge.py similarity index 100% rename from tests/test/test_knowledge.py rename to tests/test_knowledge.py diff --git a/tests/test/test_kplib.py b/tests/test_kplib.py similarity index 100% rename from tests/test/test_kplib.py rename to tests/test_kplib.py diff --git a/tests/test/test_mcp_server.py b/tests/test_mcp_server.py similarity index 94% rename from tests/test/test_mcp_server.py rename to tests/test_mcp_server.py index fc825ab..03fd0e6 100644 --- a/tests/test/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -14,17 +14,21 @@ from mcp.shared.context import RequestContext from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent +from conftest import EPISODE_53_INDEX + @pytest.fixture def server_params() -> StdioServerParameters: - """Create MCP server parameters with minimal environment.""" - env = {} + """Create MCP server parameters with environment inherited from parent process.""" + # Start with the full environment - subprocess needs PATH, PYTHONPATH, etc. + env = dict(os.environ) + # Coverage support if "COVERAGE_PROCESS_START" in os.environ: env["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_PROCESS_START"] return StdioServerParameters( command=sys.executable, - args=["-m", "typeagent.mcp.server"], + args=["-m", "typeagent.mcp.server", "--podcast-index", EPISODE_53_INDEX], env=env, ) diff --git a/tests/test/test_message_text_index_population.py b/tests/test_message_text_index_population.py similarity index 100% rename from tests/test/test_message_text_index_population.py rename to tests/test_message_text_index_population.py diff --git a/tests/test/test_message_text_index_serialization.py b/tests/test_message_text_index_serialization.py similarity index 100% rename from tests/test/test_message_text_index_serialization.py rename to tests/test_message_text_index_serialization.py diff --git a/tests/test/test_messageindex.py b/tests/test_messageindex.py similarity index 100% rename from tests/test/test_messageindex.py rename to tests/test_messageindex.py diff --git a/tests/test/test_online.py b/tests/test_online.py similarity index 100% rename from tests/test/test_online.py rename to tests/test_online.py diff --git a/tests/test/test_podcast_incremental.py b/tests/test_podcast_incremental.py similarity index 100% rename from tests/test/test_podcast_incremental.py rename to tests/test_podcast_incremental.py diff --git a/tests/test/test_podcasts.py b/tests/test_podcasts.py similarity index 95% rename from tests/test/test_podcasts.py rename to tests/test_podcasts.py index 2ac9b3a..6f901a7 100644 --- a/tests/test/test_podcasts.py +++ b/tests/test_podcasts.py @@ -13,9 +13,7 @@ from typeagent.podcasts import podcast_ingest from typeagent.podcasts.podcast import Podcast -tests_dir = os.path.dirname(__file__) -root_dir = os.path.dirname(tests_dir) -DEFAULT_FILE = os.path.join(root_dir, "testdata", "FakePodcast.txt") +from conftest import FAKE_PODCAST_TXT @pytest.mark.asyncio @@ -25,7 +23,7 @@ async def test_ingest_podcast( # Import the podcast settings = ConversationSettings(embedding_model) pod = await podcast_ingest.ingest_podcast( - DEFAULT_FILE, + FAKE_PODCAST_TXT, settings, None, Datetime.now(timezone.utc), # Use timezone-aware datetime diff --git a/tests/test/test_property_index_population.py b/tests/test_property_index_population.py similarity index 100% rename from tests/test/test_property_index_population.py rename to tests/test_property_index_population.py diff --git a/tests/test/test_propindex.py b/tests/test_propindex.py similarity index 100% rename from tests/test/test_propindex.py rename to tests/test_propindex.py diff --git a/tests/test/test_query.py b/tests/test_query.py similarity index 100% rename from tests/test/test_query.py rename to tests/test_query.py diff --git a/tests/test/test_query_method.py b/tests/test_query_method.py similarity index 100% rename from tests/test/test_query_method.py rename to tests/test_query_method.py diff --git a/tests/test/test_related_terms_fast.py b/tests/test_related_terms_fast.py similarity index 100% rename from tests/test/test_related_terms_fast.py rename to tests/test_related_terms_fast.py diff --git a/tests/test/test_related_terms_index_population.py b/tests/test_related_terms_index_population.py similarity index 100% rename from tests/test/test_related_terms_index_population.py rename to tests/test_related_terms_index_population.py diff --git a/tests/test/test_reltermsindex.py b/tests/test_reltermsindex.py similarity index 100% rename from tests/test/test_reltermsindex.py rename to tests/test_reltermsindex.py diff --git a/tests/test/test_searchlib.py b/tests/test_searchlib.py similarity index 100% rename from tests/test/test_searchlib.py rename to tests/test_searchlib.py diff --git a/tests/test/test_secindex.py b/tests/test_secindex.py similarity index 100% rename from tests/test/test_secindex.py rename to tests/test_secindex.py diff --git a/tests/test/test_secindex_storage_integration.py b/tests/test_secindex_storage_integration.py similarity index 100% rename from tests/test/test_secindex_storage_integration.py rename to tests/test_secindex_storage_integration.py diff --git a/tests/test/test_semrefindex.py b/tests/test_semrefindex.py similarity index 100% rename from tests/test/test_semrefindex.py rename to tests/test_semrefindex.py diff --git a/tests/test/test_serialization.py b/tests/test_serialization.py similarity index 100% rename from tests/test/test_serialization.py rename to tests/test_serialization.py diff --git a/tests/test/test_sqlite_indexes.py b/tests/test_sqlite_indexes.py similarity index 100% rename from tests/test/test_sqlite_indexes.py rename to tests/test_sqlite_indexes.py diff --git a/tests/test/test_sqlitestore.py b/tests/test_sqlitestore.py similarity index 100% rename from tests/test/test_sqlitestore.py rename to tests/test_sqlitestore.py diff --git a/tests/test/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py similarity index 100% rename from tests/test/test_storage_providers_unified.py rename to tests/test_storage_providers_unified.py diff --git a/tests/test/test_timestampindex.py b/tests/test_timestampindex.py similarity index 100% rename from tests/test/test_timestampindex.py rename to tests/test_timestampindex.py diff --git a/tests/test/test_transcripts.py b/tests/test_transcripts.py similarity index 97% rename from tests/test/test_transcripts.py rename to tests/test_transcripts.py index 2efef03..9079833 100644 --- a/tests/test/test_transcripts.py +++ b/tests/test_transcripts.py @@ -21,6 +21,8 @@ webvtt_timestamp_to_seconds, ) +from conftest import CONFUSE_A_CAT_VTT, has_testdata_file, PARROT_SKETCH_VTT + def test_extract_speaker_from_text(): """Test speaker extraction from various text formats.""" @@ -67,12 +69,12 @@ def test_webvtt_timestamp_conversion(): @pytest.mark.skipif( - not os.path.exists("tests/testdata/Confuse-A-Cat.vtt"), + not has_testdata_file("Confuse-A-Cat.vtt"), reason="Test VTT file not found", ) def test_get_transcript_info(): """Test getting basic information from a VTT file.""" - vtt_file = "tests/testdata/Confuse-A-Cat.vtt" + vtt_file = CONFUSE_A_CAT_VTT # Test duration duration = get_transcript_duration(vtt_file) @@ -93,7 +95,7 @@ def conversation_settings( @pytest.mark.skipif( - not os.path.exists("tests/testdata/Confuse-A-Cat.vtt"), + not has_testdata_file("Confuse-A-Cat.vtt"), reason="Test VTT file not found", ) @pytest.mark.asyncio @@ -108,7 +110,7 @@ async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex from typeagent.transcripts.transcript_ingest import parse_voice_tags - vtt_file = "tests/testdata/Confuse-A-Cat.vtt" + vtt_file = CONFUSE_A_CAT_VTT # Use in-memory storage to avoid database cleanup issues settings = conversation_settings @@ -264,7 +266,7 @@ async def test_transcript_knowledge_extraction_slow( settings = ConversationSettings(embedding_model) # Parse first 5 captions from Parrot Sketch - vtt_file = "tests/testdata/Parrot_Sketch.vtt" + vtt_file = PARROT_SKETCH_VTT if not os.path.exists(vtt_file): pytest.skip(f"Test file {vtt_file} not found") diff --git a/tests/test/test_utils.py b/tests/test_utils.py similarity index 100% rename from tests/test/test_utils.py rename to tests/test_utils.py diff --git a/tests/test/test_vectorbase.py b/tests/test_vectorbase.py similarity index 100% rename from tests/test/test_vectorbase.py rename to tests/test_vectorbase.py diff --git a/tools/ingest_podcast.py b/tools/ingest_podcast.py index 90e80ec..6ff9cdc 100644 --- a/tools/ingest_podcast.py +++ b/tools/ingest_podcast.py @@ -6,7 +6,6 @@ from typeagent.knowpro.convsettings import ConversationSettings from typeagent.podcasts.podcast_ingest import ingest_podcast -DEFAULT_TRANSCRIPT = "testdata/Episode_53_AdrianTchaikovsky.txt" CHARS_PER_MINUTE = 1050 # My guess for average speech rate incl. overhead @@ -50,11 +49,16 @@ async def main(): if args.database is not None and args.json_output is not None: raise SystemExit("Please use at most one of --database and --json-output") if args.transcript is None: - if os.path.exists(DEFAULT_TRANSCRIPT): - args.transcript = DEFAULT_TRANSCRIPT - print("Reading default transcript:", DEFAULT_TRANSCRIPT) - else: - raise SystemExit("Please provide a transcript file to ingest") + raise SystemExit( + "Error: A transcript file is required.\n" + "Usage: python ingest_podcast.py \n" + "Example: python ingest_podcast.py path/to/transcript.vtt" + ) + if not os.path.exists(args.transcript): + raise SystemExit( + f"Error: Transcript file not found: {args.transcript}\n" + "Please verify the path exists and is accessible." + ) load_dotenv() diff --git a/tools/query.py b/tools/query.py index d11f643..83e4339 100644 --- a/tools/query.py +++ b/tools/query.py @@ -536,6 +536,23 @@ async def main(): args = parser.parse_args() fill_in_debug_defaults(parser, args) + # Validate required podcast argument + if args.podcast is None and args.database is None: + raise SystemExit( + "Error: Either --podcast or --database is required.\n" + "Usage: python query.py --podcast \n" + " or: python query.py --database \n" + "Example: python query.py --podcast tests/testdata/Episode_53_index" + ) + if args.podcast is not None: + index_file = args.podcast + "_index.json" + if not os.path.exists(index_file): + raise SystemExit( + f"Error: Podcast index file not found: {index_file}\n" + "Please verify the path exists and is accessible.\n" + "Note: The path should exclude the '_index.json' suffix." + ) + if args.logfire: utils.setup_logfire() @@ -925,27 +942,24 @@ def make_arg_parser(description: str) -> argparse.ArgumentParser: ), ) - default_podcast_file = "tests/testdata/Episode_53_AdrianTchaikovsky_index" parser.add_argument( "--podcast", type=str, - default=default_podcast_file, + default=None, help="Path to the podcast index files (excluding the '_index.json' suffix)", ) - default_qafile = "tests/testdata/Episode_53_Answer_results.json" explain_qa = "a list of questions and answers to test the full pipeline" parser.add_argument( "--qafile", type=str, - default=default_qafile, + default=None, help=f"Path to the Answer_results.json file ({explain_qa})", ) - default_srfile = "tests/testdata/Episode_53_Search_results.json" explain_sr = "a list of intermediate results from stages 1, 2 and 3" parser.add_argument( "--srfile", type=str, - default=default_srfile, + default=None, help=f"Path to the Search_results.json file ({explain_sr})", ) parser.add_argument(