From 20adfada7297661c36a031cb9e27b9f027e212df Mon Sep 17 00:00:00 2001 From: Nels Numan Date: Thu, 19 Feb 2026 12:50:20 +0000 Subject: [PATCH] Add real-time stdio streaming runtime with framed PCM protocol - Add moshi.stdio for real-time audio I/O over stdin/stdout - Implement length-prefixed binary protocol (u32_le + payload) with kinds: - 0x00 handshake - 0x01 audio (PCM16 mono) - 0x02 text - 0x05 error - 0x06 ping - Mirror server/offline model flow (load/warmup/prompts/streaming state) - Add robust packet parser, PCM frame buffer, and binary-safe stdout writer - Add EOF drain behavior to flush delayed model outputs (--eof-drain-frames, default 32) - Wire CLI entrypoint moshi-stdio in pyproject.toml - Add stdio test with existing assets/test files for quick validation --- moshi/moshi/stdio.py | 557 +++++++++++++++++++++++++++++ moshi/pyproject.toml | 3 +- moshi/test/stdio_realtime_check.py | 185 ++++++++++ 3 files changed, 744 insertions(+), 1 deletion(-) create mode 100644 moshi/moshi/stdio.py create mode 100644 moshi/test/stdio_realtime_check.py diff --git a/moshi/moshi/stdio.py b/moshi/moshi/stdio.py new file mode 100644 index 00000000..93e5ab69 --- /dev/null +++ b/moshi/moshi/stdio.py @@ -0,0 +1,557 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +""" +Real-time stdin/stdout streaming entrypoint for PersonaPlex. + +Protocol: +- Transport frame: [u32 little-endian payload_len][payload bytes] +- payload[0] is the message kind: + 0x00 handshake (stdout) + 0x01 audio (stdin/stdout), PCM16LE mono at model sample rate + 0x02 text (stdout), UTF-8 + 0x05 error (stdout), UTF-8 + 0x06 ping (stdin), ignored +""" + +from __future__ import annotations + +import argparse +import contextlib +import os +from pathlib import Path +import struct +import sys +import tarfile +from dataclasses import dataclass +from typing import BinaryIO, Optional + +from huggingface_hub import hf_hub_download +import numpy as np +import sentencepiece +import torch + +from .client_utils import make_log +from .models import loaders, LMGen, MimiModel + +KIND_HANDSHAKE = 0x00 +KIND_AUDIO = 0x01 +KIND_TEXT = 0x02 +KIND_ERROR = 0x05 +KIND_PING = 0x06 + + +def log(level: str, msg: str): + print(make_log(level, msg), file=sys.stderr, flush=True) + + +def seed_all(seed: int): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + import random + + random.seed(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = False + + +def wrap_with_system_tags(text: str) -> str: + cleaned = text.strip() + if cleaned.startswith("") and cleaned.endswith(""): + return cleaned + return f" {cleaned} " + + +def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Optional[str]: + if voice_prompt_dir is not None: + return voice_prompt_dir + + log("info", "retrieving voice prompts") + voices_tgz = hf_hub_download(hf_repo, "voices.tgz") + voices_tgz = Path(voices_tgz) + voices_dir = voices_tgz.parent / "voices" + + if not voices_dir.exists(): + log("info", f"extracting {voices_tgz} to {voices_dir}") + with tarfile.open(voices_tgz, "r:gz") as tar: + tar.extractall(path=voices_tgz.parent) + + if not voices_dir.exists(): + raise RuntimeError("voices.tgz did not contain a 'voices/' directory") + return str(voices_dir) + + +def warmup(mimi: MimiModel, other_mimi: MimiModel, lm_gen: LMGen, device: torch.device, frame_size: int): + for _ in range(4): + chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=device) + with contextlib.redirect_stdout(sys.stderr): + codes = mimi.encode(chunk) + _ = other_mimi.encode(chunk) + for c in range(codes.shape[-1]): + tokens = lm_gen.step(codes[:, :, c : c + 1]) + if tokens is None: + continue + _ = mimi.decode(tokens[:, 1:9]) + _ = other_mimi.decode(tokens[:, 1:9]) + if device.type == "cuda": + torch.cuda.synchronize() + + +def pcm16le_bytes_to_float32(pcm16le: bytes) -> np.ndarray: + if len(pcm16le) % 2 != 0: + raise ValueError("PCM16 payload must contain an even number of bytes.") + pcm_i16 = np.frombuffer(pcm16le, dtype=" bytes: + clipped = np.clip(audio, -1.0, 1.0) + pcm_i16 = np.round(clipped * 32767.0).astype(" bytes: + if kind < 0 or kind > 255: + raise ValueError(f"Invalid packet kind: {kind}") + frame_payload = bytes([kind]) + payload + return struct.pack(" bool: + return len(self._buffer) > 0 + + def feed(self, data: bytes) -> list[bytes]: + if data: + self._buffer.extend(data) + + packets: list[bytes] = [] + while True: + if len(self._buffer) < 4: + break + payload_len = struct.unpack_from(" self.max_payload_bytes: + raise ValueError( + f"Payload length {payload_len} exceeds max_payload_bytes={self.max_payload_bytes}." + ) + end_idx = 4 + payload_len + if len(self._buffer) < end_idx: + break + packets.append(bytes(self._buffer[4:end_idx])) + del self._buffer[:end_idx] + return packets + + +class PacketWriter: + def __init__(self, stream: BinaryIO, max_payload_bytes: int): + self.stream = stream + self.max_payload_bytes = max_payload_bytes + self.closed = False + + def send(self, kind: int, payload: bytes = b"") -> bool: + if self.closed: + return False + if len(payload) + 1 > self.max_payload_bytes: + raise ValueError( + f"payload for kind={kind} exceeds max_payload_bytes={self.max_payload_bytes}: " + f"{len(payload) + 1} bytes" + ) + packet = encode_packet(kind, payload) + try: + self.stream.write(packet) + self.stream.flush() + return True + except (BrokenPipeError, OSError): + self.closed = True + log("warning", "stdout pipe closed; stopping.") + return False + + def send_error(self, message: str) -> bool: + return self.send(KIND_ERROR, message.encode("utf-8", errors="replace")) + + +class PCMFrameBuffer: + """Buffers PCM16 bytes and emits exact model-sized float32 frames.""" + + def __init__(self, frame_size_samples: int): + self.frame_size_samples = frame_size_samples + self.frame_size_bytes = frame_size_samples * 2 + self._buffer = bytearray() + + def append_pcm16(self, payload: bytes): + if len(payload) % 2 != 0: + raise ValueError("audio payload length must be an even number of bytes.") + self._buffer.extend(payload) + + def pop_complete_frames(self) -> list[np.ndarray]: + frames: list[np.ndarray] = [] + while len(self._buffer) >= self.frame_size_bytes: + frame_bytes = bytes(self._buffer[: self.frame_size_bytes]) + del self._buffer[: self.frame_size_bytes] + frames.append(pcm16le_bytes_to_float32(frame_bytes)) + return frames + + def flush_padded_frame(self) -> Optional[np.ndarray]: + if len(self._buffer) == 0: + return None + if len(self._buffer) % 2 != 0: + # Should not happen if append_pcm16 validated, but keep this safe. + self._buffer = self._buffer[:-1] + padded = bytes(self._buffer) + b"\x00" * (self.frame_size_bytes - len(self._buffer)) + self._buffer.clear() + return pcm16le_bytes_to_float32(padded) + + +@dataclass +class RuntimeState: + mimi: MimiModel + other_mimi: MimiModel + lm_gen: LMGen + text_tokenizer: sentencepiece.SentencePieceProcessor + device: torch.device + frame_size: int + + +def _init_runtime( + voice_prompt_path: str, + text_prompt: str, + tokenizer_path: Optional[str], + moshi_weight: Optional[str], + mimi_weight: Optional[str], + hf_repo: str, + device: torch.device, + seed: int, + temp_audio: float, + temp_text: float, + topk_audio: int, + topk_text: int, + greedy: bool, + cpu_offload: bool, + save_voice_prompt_embeddings: bool, +) -> RuntimeState: + if seed != -1: + seed_all(seed) + + hf_hub_download(hf_repo, "config.json") + + log("info", "loading mimi") + if mimi_weight is None: + mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME) # type: ignore + mimi = loaders.get_mimi(mimi_weight, device) + other_mimi = loaders.get_mimi(mimi_weight, device) + log("info", "mimi loaded") + + if tokenizer_path is None: + tokenizer_path = hf_hub_download(hf_repo, loaders.TEXT_TOKENIZER_NAME) # type: ignore + text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path) # type: ignore + + log("info", "loading moshi") + if moshi_weight is None: + moshi_weight = hf_hub_download(hf_repo, loaders.MOSHI_NAME) # type: ignore + lm = loaders.get_moshi_lm(moshi_weight, device=device, cpu_offload=cpu_offload) + lm.eval() + log("info", "moshi loaded") + + frame_size = int(mimi.sample_rate / mimi.frame_rate) + lm_gen = LMGen( + lm, + audio_silence_frame_cnt=int(0.5 * mimi.frame_rate), + sample_rate=mimi.sample_rate, + device=device, + frame_rate=mimi.frame_rate, + save_voice_prompt_embeddings=save_voice_prompt_embeddings, + use_sampling=not greedy, + temp=temp_audio, + temp_text=temp_text, + top_k=topk_audio, + top_k_text=topk_text, + ) + + mimi.streaming_forever(1) + other_mimi.streaming_forever(1) + lm_gen.streaming_forever(1) + + log("info", "warming up model") + warmup(mimi, other_mimi, lm_gen, device, frame_size) + + if voice_prompt_path.endswith(".pt"): + lm_gen.load_voice_prompt_embeddings(voice_prompt_path) + else: + lm_gen.load_voice_prompt(voice_prompt_path) + lm_gen.text_prompt_tokens = ( + text_tokenizer.encode(wrap_with_system_tags(text_prompt)) if len(text_prompt) > 0 else None + ) + + mimi.reset_streaming() + other_mimi.reset_streaming() + lm_gen.reset_streaming() + with contextlib.redirect_stdout(sys.stderr): + lm_gen.step_system_prompts(mimi) + mimi.reset_streaming() + log("info", "done with system prompts") + + return RuntimeState( + mimi=mimi, + other_mimi=other_mimi, + lm_gen=lm_gen, + text_tokenizer=text_tokenizer, + device=device, + frame_size=frame_size, + ) + + +def _emit_model_step_output(state: RuntimeState, writer: PacketWriter, frame: np.ndarray) -> bool: + chunk = torch.from_numpy(frame).to(device=state.device)[None, None] + with contextlib.redirect_stdout(sys.stderr): + codes = state.mimi.encode(chunk) + _ = state.other_mimi.encode(chunk) + + for c in range(codes.shape[-1]): + tokens = state.lm_gen.step(codes[:, :, c : c + 1]) + if tokens is None: + continue + decoded = state.mimi.decode(tokens[:, 1:9]) + _ = state.other_mimi.decode(tokens[:, 1:9]) + + pcm = decoded.detach().cpu().numpy()[0, 0] + if not writer.send(KIND_AUDIO, float32_to_pcm16le_bytes(pcm)): + return False + + text_token = tokens[0, 0, 0].item() + if text_token not in (0, 3): + text_piece = state.text_tokenizer.id_to_piece(text_token) # type: ignore + text_piece = text_piece.replace("▁", " ") + if not writer.send(KIND_TEXT, text_piece.encode("utf-8")): + return False + return True + + +def run_stdio_stream( + state: RuntimeState, + writer: PacketWriter, + stdin_stream: BinaryIO, + read_size: int, + max_payload_bytes: int, + eof_drain_frames: int, +) -> int: + parser = LengthPrefixedParser(max_payload_bytes=max_payload_bytes) + pcm_buffer = PCMFrameBuffer(frame_size_samples=state.frame_size) + + if not writer.send(KIND_HANDSHAKE): + return 1 + + while True: + chunk = stdin_stream.read(read_size) + if not chunk: + break + try: + messages = parser.feed(chunk) + except ValueError as exc: + err_msg = f"protocol parse error: {exc}" + log("error", err_msg) + writer.send_error(err_msg) + return 2 + + for message in messages: + if not message: + msg = "received empty payload." + log("warning", msg) + writer.send_error(msg) + continue + kind = message[0] + payload = message[1:] + if kind == KIND_AUDIO: + try: + pcm_buffer.append_pcm16(payload) + except ValueError as exc: + err_msg = f"invalid audio payload: {exc}" + log("warning", err_msg) + writer.send_error(err_msg) + continue + for frame in pcm_buffer.pop_complete_frames(): + if not _emit_model_step_output(state, writer, frame): + return 0 + elif kind == KIND_PING: + continue + else: + msg = f"unknown message kind {kind}" + log("warning", msg) + writer.send_error(msg) + + if parser.has_pending_data: + msg = "stdin ended with incomplete frame payload." + log("warning", msg) + writer.send_error(msg) + + final_frame = pcm_buffer.flush_padded_frame() + if final_frame is not None: + if not _emit_model_step_output(state, writer, final_frame): + return 0 + + if eof_drain_frames > 0: + log("info", f"draining EOF with {eof_drain_frames} silence frames") + zero_frame = np.zeros((state.frame_size,), dtype=np.float32) + for _ in range(eof_drain_frames): + if not _emit_model_step_output(state, writer, zero_frame): + return 0 + return 0 + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Real-time stdin/stdout streaming for PersonaPlex. " + "Input and output use [u32 payload_len][payload] framing." + ) + ) + parser.add_argument( + "--voice-prompt", + required=True, + type=str, + help="Voice prompt filename (basename) inside --voice-prompt-dir (e.g. 'NATM1.pt').", + ) + parser.add_argument( + "--voice-prompt-dir", + type=str, + help=( + "Directory containing voice prompt files. " + "If omitted, voices.tgz is downloaded from HF and extracted." + ), + ) + parser.add_argument( + "--text-prompt", + default=( + "You are a wise and friendly teacher. " + "Answer questions or provide advice in a clear and engaging way." + ), + type=str, + help="Text prompt.", + ) + parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") + parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.") + parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.") + parser.add_argument( + "--hf-repo", + type=str, + default=loaders.DEFAULT_REPO, + help="HF repo to look into (defaults to pre-trained model repo).", + ) + parser.add_argument("--temp-audio", type=float, default=0.8, help="Audio sampling temperature.") + parser.add_argument("--temp-text", type=float, default=0.7, help="Text sampling temperature.") + parser.add_argument("--topk-audio", type=int, default=250, help="Audio top-k sampling.") + parser.add_argument("--topk-text", type=int, default=25, help="Text top-k sampling.") + parser.add_argument("--greedy", action="store_true", help="Disable sampling (greedy decoding).") + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device on which to run, defaults to 'cuda'.", + ) + parser.add_argument( + "--cpu-offload", + action="store_true", + help="Offload LM model layers to CPU when GPU memory is insufficient.", + ) + parser.add_argument("--seed", type=int, default=-1, help="Seed for reproducibility (-1 disables).") + parser.add_argument( + "--read-size", + type=int, + default=4096, + help="Number of bytes per stdin read iteration.", + ) + parser.add_argument( + "--max-payload-bytes", + type=int, + default=8 * 1024 * 1024, + help="Maximum payload bytes for any framed message (after u32 length).", + ) + parser.add_argument( + "--eof-drain-frames", + type=int, + default=32, + help=( + "Extra silence frames to process after stdin EOF. " + "Default: 32 (~2.56s at 12.5fps). Use -1 for auto (LM max delay), 0 to disable." + ), + ) + args = parser.parse_args() + + voice_prompt_dir = _get_voice_prompt_dir(args.voice_prompt_dir, args.hf_repo) + if not os.path.exists(voice_prompt_dir): + raise FileNotFoundError(f"voice_prompt_dir does not exist: {voice_prompt_dir}") + voice_prompt_path = os.path.join(voice_prompt_dir, args.voice_prompt) + if not os.path.exists(voice_prompt_path): + raise FileNotFoundError( + f"Voice prompt '{args.voice_prompt}' not found in " + f"'{voice_prompt_dir}' (resolved: {voice_prompt_path})" + ) + + wire_stdout = sys.stdout.buffer + sys.stdout = sys.stderr + writer = PacketWriter(stream=wire_stdout, max_payload_bytes=args.max_payload_bytes) + + device = torch.device(args.device) + with torch.no_grad(): + state = _init_runtime( + voice_prompt_path=voice_prompt_path, + text_prompt=args.text_prompt, + tokenizer_path=args.tokenizer, + moshi_weight=args.moshi_weight, + mimi_weight=args.mimi_weight, + hf_repo=args.hf_repo, + device=device, + seed=args.seed, + temp_audio=args.temp_audio, + temp_text=args.temp_text, + topk_audio=args.topk_audio, + topk_text=args.topk_text, + greedy=bool(args.greedy), + cpu_offload=args.cpu_offload, + save_voice_prompt_embeddings=False, + ) + eof_drain_frames = ( + int(state.lm_gen.max_delay) if int(args.eof_drain_frames) < 0 else int(args.eof_drain_frames) + ) + code = run_stdio_stream( + state=state, + writer=writer, + stdin_stream=sys.stdin.buffer, + read_size=args.read_size, + max_payload_bytes=args.max_payload_bytes, + eof_drain_frames=eof_drain_frames, + ) + raise SystemExit(code) + + +if __name__ == "__main__": + main() diff --git a/moshi/pyproject.toml b/moshi/pyproject.toml index ead71e9c..824e897f 100644 --- a/moshi/pyproject.toml +++ b/moshi/pyproject.toml @@ -22,10 +22,11 @@ readme = "README.md" [project.scripts] moshi-server = "moshi.server:main" moshi-offline = "moshi.offline:main" +moshi-stdio = "moshi.stdio:main" [tool.setuptools.dynamic] version = {attr = "moshi.__version__"} [build-system] requires = ["setuptools"] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" diff --git a/moshi/test/stdio_realtime_check.py b/moshi/test/stdio_realtime_check.py new file mode 100644 index 00000000..6b85609e --- /dev/null +++ b/moshi/test/stdio_realtime_check.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +""" +Concise check for moshi.stdio using assets in assets/test. + +Default behavior: validate framing roundtrip (no model run). +Runs end-to-end moshi.stdio by default. +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path +import struct +import subprocess +import sys +import wave + +KIND_AUDIO = 0x01 + + +def encode_packet(kind: int, payload: bytes = b"") -> bytes: + frame_payload = bytes([kind]) + payload + return struct.pack(" list[bytes]: + out: list[bytes] = [] + offset = 0 + n = len(data) + while True: + if offset + 4 > n: + if offset == n: + return out + raise ValueError(f"incomplete frame header at offset {offset}") + payload_len = struct.unpack_from(" max_payload_bytes: + raise ValueError(f"payload_len={payload_len} exceeds {max_payload_bytes} at offset {offset}") + offset += 4 + if offset + payload_len > n: + raise ValueError(f"incomplete frame payload at offset {offset - 4}") + out.append(data[offset : offset + payload_len]) + offset += payload_len + + +def read_wav_pcm16_mono(path: Path) -> tuple[bytes, int]: + with wave.open(str(path), "rb") as wav_file: + channels = wav_file.getnchannels() + sample_width = wav_file.getsampwidth() + sample_rate = wav_file.getframerate() + pcm_bytes = wav_file.readframes(wav_file.getnframes()) + if channels != 1: + raise ValueError(f"expected mono wav; got channels={channels}") + if sample_width != 2: + raise ValueError(f"expected 16-bit wav; got sample_width={sample_width}") + return pcm_bytes, sample_rate + + +def write_wav_pcm16_mono(path: Path, pcm_bytes: bytes, sample_rate: int): + with wave.open(str(path), "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(sample_rate) + wav_file.writeframes(pcm_bytes) + + +def framing_roundtrip(pcm_bytes: bytes, sample_rate: int, chunk_bytes: int, out_wav: Path | None): + framed = bytearray() + for i in range(0, len(pcm_bytes), chunk_bytes): + framed.extend(encode_packet(KIND_AUDIO, pcm_bytes[i : i + chunk_bytes])) + + out_audio = bytearray() + for payload in decode_packets(bytes(framed)): + if payload[0] != KIND_AUDIO: + continue + out_audio.extend(payload[1:]) + + if bytes(out_audio) != pcm_bytes: + raise AssertionError("framing roundtrip mismatch") + if out_wav is not None: + write_wav_pcm16_mono(out_wav, bytes(out_audio), sample_rate) + + +def run_e2e( + pcm_bytes: bytes, + chunk_bytes: int, + python_bin: str, + moshi_args: list[str], + output_wav: Path, + output_text: Path | None, +): + framed_in = bytearray() + for i in range(0, len(pcm_bytes), chunk_bytes): + framed_in.extend(encode_packet(KIND_AUDIO, pcm_bytes[i : i + chunk_bytes])) + + env = os.environ.copy() + package_root = str(Path(__file__).resolve().parents[1]) # .../personaplex/moshi + existing_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = f"{package_root}:{existing_pythonpath}" if existing_pythonpath else package_root + + cmd = [python_bin, "-m", "moshi.stdio"] + moshi_args + print("running:", " ".join(cmd), file=sys.stderr) + proc = subprocess.run( + cmd, + input=bytes(framed_in), + stdout=subprocess.PIPE, + stderr=None, + env=env, + check=False, + ) + if proc.returncode != 0: + raise RuntimeError(f"moshi.stdio exited with code {proc.returncode}") + + packets = decode_packets(proc.stdout) + audio_out = bytearray() + text_out: list[str] = [] + for payload in packets: + kind = payload[0] + body = payload[1:] + if kind == KIND_AUDIO: + audio_out.extend(body) + elif kind == 0x02: + text_out.append(body.decode("utf-8", errors="replace")) + + write_wav_pcm16_mono(output_wav, bytes(audio_out), 24000) + if output_text is not None: + output_text.write_text("".join(text_out), encoding="utf-8") + + +def main() -> int: + parser = argparse.ArgumentParser(description="Test runner for moshi.stdio.") + parser.add_argument( + "--input-wav", + default=str(Path(__file__).resolve().parents[2] / "assets/test/input_assistant.wav"), + help="Input WAV (mono PCM16).", + ) + parser.add_argument("--chunk-bytes", type=int, default=1000, help="PCM bytes per audio packet.") + parser.add_argument("--out-wav", help="Optional output wav for framing roundtrip.") + parser.add_argument("--python-bin", default=sys.executable, help="Python executable for subprocess.") + parser.add_argument( + "--moshi-args", + nargs=argparse.REMAINDER, + default=[], + help="Args forwarded to `python -m moshi.stdio`.", + ) + repo_root = Path(__file__).resolve().parents[2] + parser.add_argument( + "--e2e-out-wav", + default=str(repo_root / "moshi_stdio_out.wav"), + help="Output wav for e2e.", + ) + parser.add_argument( + "--e2e-out-text", + default=str(repo_root / "moshi_stdio_out.txt"), + help="Optional output text for e2e.", + ) + args = parser.parse_args() + + pcm_bytes, sample_rate = read_wav_pcm16_mono(Path(args.input_wav)) + out_wav = Path(args.out_wav) if args.out_wav else None + framing_roundtrip(pcm_bytes, sample_rate, args.chunk_bytes, out_wav) + + if not args.moshi_args: + raise SystemExit( + "moshi args required. Example: " + "--moshi-args --voice-prompt NATM1.pt --device cuda" + ) + + output_wav = Path(args.e2e_out_wav) + output_text = Path(args.e2e_out_text) if args.e2e_out_text else None + run_e2e( + pcm_bytes=pcm_bytes, + chunk_bytes=args.chunk_bytes, + python_bin=args.python_bin, + moshi_args=args.moshi_args, + output_wav=output_wav, + output_text=output_text, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())