Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions src/typeagent/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
from dataclasses import dataclass
import os
import time
from typing import Any

Expand All @@ -28,6 +29,9 @@
from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex
from typeagent.storage.utils import create_storage_provider

# Example podcast index path for documentation and error messages
_EXAMPLE_PODCAST_INDEX = "tests/testdata/Episode_53_AdrianTchaikovsky_index"


class MCPTypeChatModel(typechat.TypeChatLanguageModel):
"""TypeChat language model that uses MCP sampling API instead of direct API calls."""
Expand Down Expand Up @@ -142,7 +146,9 @@ async def make_context(
entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
)

query_context = await load_podcast_index_or_database(settings, dbname)
query_context = await load_podcast_database_or_index(
settings, dbname, _podcast_index
)

# Use MCP-based model instead of one that requires API keys
model = MCPTypeChatModel(session)
Expand All @@ -161,24 +167,33 @@ async def make_context(
return context


async def load_podcast_index_or_database(
async def load_podcast_database_or_index(
settings: ConversationSettings,
dbname: str | None = None,
podcast_index: str | None = None,
) -> query.QueryEvalContext[podcast.PodcastMessage, Any]:
if dbname is None:
conversation = await podcast.Podcast.read_from_file(
"tests/testdata/Episode_53_AdrianTchaikovsky_index", settings
)
else:
if dbname is not None:
# Load from SQLite database
conversation = await podcast.Podcast.create(settings)
elif podcast_index is not None:
# Load from JSON index files
conversation = await podcast.Podcast.read_from_file(podcast_index, settings)
else:
raise ValueError(
"Either --database or --podcast-index must be specified. "
"Use --podcast-index to specify the path to podcast index files "
f"(e.g., '{_EXAMPLE_PODCAST_INDEX}')."
Comment on lines +183 to +185
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, excellent. In practice the server should almost always use --database, which should be pre-filled by running one of the ingest tools.

)
return query.QueryEvalContext(conversation)


# Create an MCP server
mcp = FastMCP("typagent")

# Global variable to store database path (set via command-line argument)
# Global variables to store command-line arguments
# (no other straightforward way to pass to tool handlers)
_dbname: str | None = None
_podcast_index: str | None = None


@dataclass
Expand Down Expand Up @@ -245,12 +260,49 @@ async def query_conversation(
"--database",
type=str,
default=None,
help="Path to the SQLite database file (default: load from JSON file)",
help="Path to a SQLite database file with pre-indexed podcast data",
)
parser.add_argument(
"-p",
"--podcast-index",
type=str,
default=None,
help="Path to podcast index files (excluding '_data.json' suffix), "
f"e.g., '{_EXAMPLE_PODCAST_INDEX}'",
Comment on lines +265 to +271
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point in the future we should be able to remove --podcast-index altogether, but not yet (and the test uses it).

)
args = parser.parse_args()

# Store database path in global variable (no other straightforward way to pass to tool)
# Validate arguments
if args.database is None and args.podcast_index is None:
parser.error(
"Either --database or --podcast-index is required.\n"
"Example: python -m typeagent.mcp.server "
f"--podcast-index {_EXAMPLE_PODCAST_INDEX}"
)

if args.database is not None and args.podcast_index is not None:
parser.error("Cannot specify both --database and --podcast-index")

# Validate file existence
if args.database is not None and not os.path.exists(args.database):
parser.error(
f"Database file not found: {args.database}\n"
"Please provide a valid path to an existing SQLite database."
)

if args.podcast_index is not None:
data_file = args.podcast_index + "_data.json"
if not os.path.exists(data_file):
parser.error(
f"Podcast index file not found: {data_file}\n"
"Please provide a valid path to podcast index files "
"(without the '_data.json' suffix).\n"
f"Example: {_EXAMPLE_PODCAST_INDEX}"
)

# Store in global variables for tool handlers
_dbname = args.database
_podcast_index = args.podcast_index

# Use stdio transport for simplicity
mcp.run(transport="stdio")
45 changes: 44 additions & 1 deletion tests/test/conftest.py → tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from collections.abc import AsyncGenerator, Iterator
from collections.abc import AsyncGenerator, Callable, Iterator
import os
from pathlib import Path
import tempfile
from typing import Any

Expand Down Expand Up @@ -43,6 +44,37 @@
MemorySemanticRefCollection,
)

# --- Testdata path utilities ---
# Locate the tests directory relative to this file
_TESTS_DIR = Path(__file__).resolve().parent # tests/
_TESTDATA_DIR = _TESTS_DIR / "testdata"
_REPO_ROOT = _TESTS_DIR.parent


def get_testdata_path(filename: str) -> str:
    """Build the path of *filename* inside the tests/testdata/ directory.

    The name is joined onto the module-level ``_TESTDATA_DIR`` and the result
    is returned as a plain string. No check is made that the file exists.
    """
    target = _TESTDATA_DIR / filename
    return str(target)


def get_repo_root() -> Path:
    """Expose the repository root directory as a ``Path``.

    This simply hands back the module-level ``_REPO_ROOT`` value.
    """
    repo_root: Path = _REPO_ROOT
    return repo_root


def has_testdata_file(filename: str) -> bool:
    """Report whether *filename* is present under tests/testdata/.

    Handy as the condition of a ``skipif`` marker for tests that depend on
    optional data files.
    """
    candidate = _TESTDATA_DIR.joinpath(filename)
    return candidate.exists()


# Commonly used test files as constants
CONFUSE_A_CAT_VTT = get_testdata_path("Confuse-A-Cat.vtt")
PARROT_SKETCH_VTT = get_testdata_path("Parrot_Sketch.vtt")
FAKE_PODCAST_TXT = get_testdata_path("FakePodcast.txt")
EPISODE_53_INDEX = get_testdata_path("Episode_53_AdrianTchaikovsky_index")
EPISODE_53_TRANSCRIPT = get_testdata_path("Episode_53_AdrianTchaikovsky.txt")
EPISODE_53_ANSWERS = get_testdata_path("Episode_53_Answer_results.json")
EPISODE_53_SEARCH = get_testdata_path("Episode_53_Search_results.json")


@pytest.fixture(scope="session")
def needs_auth() -> None:
Expand All @@ -63,6 +95,17 @@ def embedding_model() -> AsyncEmbeddingModel:
return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)


@pytest.fixture(scope="session")
def testdata_path() -> Callable[[str], str]:
    """Session-scoped fixture that hands tests the testdata path resolver.

    The returned callable is ``get_testdata_path``: give it a file name and it
    produces the matching path under tests/testdata/ as a string.

    Example:
        def test_something(testdata_path):
            path = testdata_path("Confuse-A-Cat.vtt")
    """
    resolver: Callable[[str], str] = get_testdata_path
    return resolver


@pytest.fixture
def temp_dir() -> Iterator[str]:
with tempfile.TemporaryDirectory() as dir:
Expand Down
File renamed without changes.
File renamed without changes.
9 changes: 3 additions & 6 deletions tests/test/test_demo.py → tests/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import argparse
import asyncio
import os
import textwrap
import time

Expand All @@ -12,23 +11,21 @@
from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal
from typeagent.podcasts import podcast

tests_dir = os.path.dirname(__file__)
root_dir = os.path.dirname(tests_dir)
DEFAULT_FILE = os.path.join(root_dir, "testdata", "Episode_53_AdrianTchaikovsky_index")
from conftest import EPISODE_53_INDEX

parser = argparse.ArgumentParser()
parser.add_argument(
"filename",
nargs="?",
type=str,
default=DEFAULT_FILE,
default=EPISODE_53_INDEX,
)


def test_main(really_needs_auth: None):
# auth is needed because we use embeddings.
# TODO: Only use the embeddings loaded from the file and cached.
asyncio.run(main(DEFAULT_FILE))
asyncio.run(main(EPISODE_53_INDEX))


async def main(filename_prefix: str):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
)
from typeagent.transcripts.transcript_ingest import ingest_vtt_transcript

tests_dir = os.path.dirname(__file__)
root_dir = os.path.dirname(tests_dir)
from conftest import CONFUSE_A_CAT_VTT, PARROT_SKETCH_VTT


@pytest.mark.asyncio
Expand Down Expand Up @@ -144,9 +143,8 @@ async def test_incremental_index_with_vtt_files():
# First VTT file ingestion
print("\n=== Import first VTT file ===")
# Import the first transcript
DEFAULT_FILE = os.path.join(root_dir, "testdata", "Confuse-A-Cat.vtt")
transcript1 = await ingest_vtt_transcript(
DEFAULT_FILE,
CONFUSE_A_CAT_VTT,
settings,
dbname=db_path,
)
Expand All @@ -169,9 +167,8 @@ async def test_incremental_index_with_vtt_files():
settings2.semantic_ref_index_settings.auto_extract_knowledge = False

# Ingest the second transcript
DEFAULT_FILE = os.path.join(root_dir, "testdata", "Parrot_Sketch.vtt")
transcript2 = await ingest_vtt_transcript(
DEFAULT_FILE,
PARROT_SKETCH_VTT,
settings2,
dbname=db_path,
)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 7 additions & 3 deletions tests/test/test_mcp_server.py → tests/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@
from mcp.shared.context import RequestContext
from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent

from conftest import EPISODE_53_INDEX


@pytest.fixture
def server_params() -> StdioServerParameters:
"""Create MCP server parameters with minimal environment."""
env = {}
"""Create MCP server parameters with environment inherited from parent process."""
# Start with the full environment - subprocess needs PATH, PYTHONPATH, etc.
env = dict(os.environ)
# Coverage support
if "COVERAGE_PROCESS_START" in os.environ:
env["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_PROCESS_START"]

return StdioServerParameters(
command=sys.executable,
args=["-m", "typeagent.mcp.server"],
args=["-m", "typeagent.mcp.server", "--podcast-index", EPISODE_53_INDEX],
env=env,
)

Expand Down
File renamed without changes.
File renamed without changes.
6 changes: 2 additions & 4 deletions tests/test/test_podcasts.py → tests/test_podcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from typeagent.podcasts import podcast_ingest
from typeagent.podcasts.podcast import Podcast

tests_dir = os.path.dirname(__file__)
root_dir = os.path.dirname(tests_dir)
DEFAULT_FILE = os.path.join(root_dir, "testdata", "FakePodcast.txt")
from conftest import FAKE_PODCAST_TXT


@pytest.mark.asyncio
Expand All @@ -25,7 +23,7 @@ async def test_ingest_podcast(
# Import the podcast
settings = ConversationSettings(embedding_model)
pod = await podcast_ingest.ingest_podcast(
DEFAULT_FILE,
FAKE_PODCAST_TXT,
settings,
None,
Datetime.now(timezone.utc), # Use timezone-aware datetime
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
12 changes: 7 additions & 5 deletions tests/test/test_transcripts.py → tests/test_transcripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
webvtt_timestamp_to_seconds,
)

from conftest import CONFUSE_A_CAT_VTT, has_testdata_file, PARROT_SKETCH_VTT


def test_extract_speaker_from_text():
"""Test speaker extraction from various text formats."""
Expand Down Expand Up @@ -67,12 +69,12 @@ def test_webvtt_timestamp_conversion():


@pytest.mark.skipif(
not os.path.exists("tests/testdata/Confuse-A-Cat.vtt"),
not has_testdata_file("Confuse-A-Cat.vtt"),
reason="Test VTT file not found",
)
def test_get_transcript_info():
"""Test getting basic information from a VTT file."""
vtt_file = "tests/testdata/Confuse-A-Cat.vtt"
vtt_file = CONFUSE_A_CAT_VTT

# Test duration
duration = get_transcript_duration(vtt_file)
Expand All @@ -93,7 +95,7 @@ def conversation_settings(


@pytest.mark.skipif(
not os.path.exists("tests/testdata/Confuse-A-Cat.vtt"),
not has_testdata_file("Confuse-A-Cat.vtt"),
reason="Test VTT file not found",
)
@pytest.mark.asyncio
Expand All @@ -108,7 +110,7 @@ async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings
from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex
from typeagent.transcripts.transcript_ingest import parse_voice_tags

vtt_file = "tests/testdata/Confuse-A-Cat.vtt"
vtt_file = CONFUSE_A_CAT_VTT

# Use in-memory storage to avoid database cleanup issues
settings = conversation_settings
Expand Down Expand Up @@ -264,7 +266,7 @@ async def test_transcript_knowledge_extraction_slow(
settings = ConversationSettings(embedding_model)

# Parse first 5 captions from Parrot Sketch
vtt_file = "tests/testdata/Parrot_Sketch.vtt"
vtt_file = PARROT_SKETCH_VTT
if not os.path.exists(vtt_file):
pytest.skip(f"Test file {vtt_file} not found")

Expand Down
File renamed without changes.
File renamed without changes.
16 changes: 10 additions & 6 deletions tools/ingest_podcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typeagent.knowpro.convsettings import ConversationSettings
from typeagent.podcasts.podcast_ingest import ingest_podcast

DEFAULT_TRANSCRIPT = "testdata/Episode_53_AdrianTchaikovsky.txt"
CHARS_PER_MINUTE = 1050 # My guess for average speech rate incl. overhead


Expand Down Expand Up @@ -50,11 +49,16 @@ async def main():
if args.database is not None and args.json_output is not None:
raise SystemExit("Please use at most one of --database and --json-output")
if args.transcript is None:
if os.path.exists(DEFAULT_TRANSCRIPT):
args.transcript = DEFAULT_TRANSCRIPT
print("Reading default transcript:", DEFAULT_TRANSCRIPT)
else:
raise SystemExit("Please provide a transcript file to ingest")
raise SystemExit(
"Error: A transcript file is required.\n"
"Usage: python ingest_podcast.py <transcript_file>\n"
"Example: python ingest_podcast.py path/to/transcript.vtt"
)
if not os.path.exists(args.transcript):
raise SystemExit(
f"Error: Transcript file not found: {args.transcript}\n"
"Please verify the path exists and is accessible."
)

load_dotenv()

Expand Down
Loading