Skip to content

Commit 7e72e85

Browse files
authored
Add testdata_path fixture; refactor and unify testdata path handling (#147)
Fixes #146 (see there for details).
1 parent baa699d commit 7e72e85

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+158
-47
lines changed

src/typeagent/mcp/server.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import argparse
88
from dataclasses import dataclass
9+
import os
910
import time
1011
from typing import Any
1112

@@ -28,6 +29,9 @@
2829
from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex
2930
from typeagent.storage.utils import create_storage_provider
3031

32+
# Example podcast index path for documentation and error messages
33+
_EXAMPLE_PODCAST_INDEX = "tests/testdata/Episode_53_AdrianTchaikovsky_index"
34+
3135

3236
class MCPTypeChatModel(typechat.TypeChatLanguageModel):
3337
"""TypeChat language model that uses MCP sampling API instead of direct API calls."""
@@ -142,7 +146,9 @@ async def make_context(
142146
entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
143147
)
144148

145-
query_context = await load_podcast_index_or_database(settings, dbname)
149+
query_context = await load_podcast_database_or_index(
150+
settings, dbname, _podcast_index
151+
)
146152

147153
# Use MCP-based model instead of one that requires API keys
148154
model = MCPTypeChatModel(session)
@@ -161,24 +167,33 @@ async def make_context(
161167
return context
162168

163169

164-
async def load_podcast_index_or_database(
170+
async def load_podcast_database_or_index(
165171
settings: ConversationSettings,
166172
dbname: str | None = None,
173+
podcast_index: str | None = None,
167174
) -> query.QueryEvalContext[podcast.PodcastMessage, Any]:
168-
if dbname is None:
169-
conversation = await podcast.Podcast.read_from_file(
170-
"tests/testdata/Episode_53_AdrianTchaikovsky_index", settings
171-
)
172-
else:
175+
if dbname is not None:
176+
# Load from SQLite database
173177
conversation = await podcast.Podcast.create(settings)
178+
elif podcast_index is not None:
179+
# Load from JSON index files
180+
conversation = await podcast.Podcast.read_from_file(podcast_index, settings)
181+
else:
182+
raise ValueError(
183+
"Either --database or --podcast-index must be specified. "
184+
"Use --podcast-index to specify the path to podcast index files "
185+
f"(e.g., '{_EXAMPLE_PODCAST_INDEX}')."
186+
)
174187
return query.QueryEvalContext(conversation)
175188

176189

177190
# Create an MCP server
178191
mcp = FastMCP("typagent")
179192

180-
# Global variable to store database path (set via command-line argument)
193+
# Global variables to store command-line arguments
194+
# (no other straightforward way to pass to tool handlers)
181195
_dbname: str | None = None
196+
_podcast_index: str | None = None
182197

183198

184199
@dataclass
@@ -245,12 +260,49 @@ async def query_conversation(
245260
"--database",
246261
type=str,
247262
default=None,
248-
help="Path to the SQLite database file (default: load from JSON file)",
263+
help="Path to a SQLite database file with pre-indexed podcast data",
264+
)
265+
parser.add_argument(
266+
"-p",
267+
"--podcast-index",
268+
type=str,
269+
default=None,
270+
help="Path to podcast index files (excluding '_data.json' suffix), "
271+
f"e.g., '{_EXAMPLE_PODCAST_INDEX}'",
249272
)
250273
args = parser.parse_args()
251274

252-
# Store database path in global variable (no other straightforward way to pass to tool)
275+
# Validate arguments
276+
if args.database is None and args.podcast_index is None:
277+
parser.error(
278+
"Either --database or --podcast-index is required.\n"
279+
"Example: python -m typeagent.mcp.server "
280+
f"--podcast-index {_EXAMPLE_PODCAST_INDEX}"
281+
)
282+
283+
if args.database is not None and args.podcast_index is not None:
284+
parser.error("Cannot specify both --database and --podcast-index")
285+
286+
# Validate file existence
287+
if args.database is not None and not os.path.exists(args.database):
288+
parser.error(
289+
f"Database file not found: {args.database}\n"
290+
"Please provide a valid path to an existing SQLite database."
291+
)
292+
293+
if args.podcast_index is not None:
294+
data_file = args.podcast_index + "_data.json"
295+
if not os.path.exists(data_file):
296+
parser.error(
297+
f"Podcast index file not found: {data_file}\n"
298+
"Please provide a valid path to podcast index files "
299+
"(without the '_data.json' suffix).\n"
300+
f"Example: {_EXAMPLE_PODCAST_INDEX}"
301+
)
302+
303+
# Store in global variables for tool handlers
253304
_dbname = args.database
305+
_podcast_index = args.podcast_index
254306

255307
# Use stdio transport for simplicity
256308
mcp.run(transport="stdio")
Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
from collections.abc import AsyncGenerator, Iterator
4+
from collections.abc import AsyncGenerator, Callable, Iterator
55
import os
6+
from pathlib import Path
67
import tempfile
78
from typing import Any
89

@@ -43,6 +44,37 @@
4344
MemorySemanticRefCollection,
4445
)
4546

47+
# --- Testdata path utilities ---
48+
# Locate the tests directory relative to this file
49+
_TESTS_DIR = Path(__file__).resolve().parent # tests/
50+
_TESTDATA_DIR = _TESTS_DIR / "testdata"
51+
_REPO_ROOT = _TESTS_DIR.parent
52+
53+
54+
def get_testdata_path(filename: str) -> str:
55+
"""Return absolute path to a file in tests/testdata/."""
56+
return str(_TESTDATA_DIR / filename)
57+
58+
59+
def get_repo_root() -> Path:
60+
"""Return the repository root path."""
61+
return _REPO_ROOT
62+
63+
64+
def has_testdata_file(filename: str) -> bool:
65+
"""Check if a testdata file exists (for use in skipif conditions)."""
66+
return (_TESTDATA_DIR / filename).exists()
67+
68+
69+
# Commonly used test files as constants
70+
CONFUSE_A_CAT_VTT = get_testdata_path("Confuse-A-Cat.vtt")
71+
PARROT_SKETCH_VTT = get_testdata_path("Parrot_Sketch.vtt")
72+
FAKE_PODCAST_TXT = get_testdata_path("FakePodcast.txt")
73+
EPISODE_53_INDEX = get_testdata_path("Episode_53_AdrianTchaikovsky_index")
74+
EPISODE_53_TRANSCRIPT = get_testdata_path("Episode_53_AdrianTchaikovsky.txt")
75+
EPISODE_53_ANSWERS = get_testdata_path("Episode_53_Answer_results.json")
76+
EPISODE_53_SEARCH = get_testdata_path("Episode_53_Search_results.json")
77+
4678

4779
@pytest.fixture(scope="session")
4880
def needs_auth() -> None:
@@ -63,6 +95,17 @@ def embedding_model() -> AsyncEmbeddingModel:
6395
return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
6496

6597

98+
@pytest.fixture(scope="session")
99+
def testdata_path() -> Callable[[str], str]:
100+
"""Fixture returning a function to get absolute paths to testdata files.
101+
102+
Usage:
103+
def test_something(testdata_path):
104+
path = testdata_path("Confuse-A-Cat.vtt")
105+
"""
106+
return get_testdata_path
107+
108+
66109
@pytest.fixture
67110
def temp_dir() -> Iterator[str]:
68111
with tempfile.TemporaryDirectory() as dir:
File renamed without changes.
Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import argparse
55
import asyncio
6-
import os
76
import textwrap
87
import time
98

@@ -12,23 +11,21 @@
1211
from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal
1312
from typeagent.podcasts import podcast
1413

15-
tests_dir = os.path.dirname(__file__)
16-
root_dir = os.path.dirname(tests_dir)
17-
DEFAULT_FILE = os.path.join(root_dir, "testdata", "Episode_53_AdrianTchaikovsky_index")
14+
from conftest import EPISODE_53_INDEX
1815

1916
parser = argparse.ArgumentParser()
2017
parser.add_argument(
2118
"filename",
2219
nargs="?",
2320
type=str,
24-
default=DEFAULT_FILE,
21+
default=EPISODE_53_INDEX,
2522
)
2623

2724

2825
def test_main(really_needs_auth: None):
2926
# auth is needed because we use embeddings.
3027
# TODO: Only use the embeddings loaded from the file and cached.
31-
asyncio.run(main(DEFAULT_FILE))
28+
asyncio.run(main(EPISODE_53_INDEX))
3229

3330

3431
async def main(filename_prefix: str):

0 commit comments

Comments
 (0)