diff --git a/src/maai/__init__.py b/src/maai/__init__.py index 26fe9a2..763c9cd 100644 --- a/src/maai/__init__.py +++ b/src/maai/__init__.py @@ -1,4 +1,6 @@ +"""Public package exports for the Maai library.""" + from maai.model import Maai import maai.input as MaaiInput import maai.output as MaaiOutput -from maai.util import get_available_models \ No newline at end of file +from maai.util import get_available_models diff --git a/src/maai/encoder.py b/src/maai/encoder.py index 0b8b23c..74ced21 100644 --- a/src/maai/encoder.py +++ b/src/maai/encoder.py @@ -1,3 +1,5 @@ +"""Audio encoder wrappers and utilities.""" + import torch import torch.nn as nn import einops @@ -48,19 +50,30 @@ def __init__(self, load_pretrained=True, freeze=True, cpc_model=''): self.freeze() def get_default_conf(self): + """Return a placeholder default configuration.""" return {""} def freeze(self): + """Freeze encoder parameters for inference.""" for p in self.encoder.parameters(): p.requires_grad_(False) print(f"Froze {self.__class__.__name__}!") def unfreeze(self): + """Unfreeze encoder parameters for training.""" for p in self.encoder.parameters(): p.requires_grad_(True) print(f"Trainable {self.__class__.__name__}!") def forward(self, waveform): + """Encode waveform frames into downsampled representations. + + Args: + waveform: Input waveform tensor. + + Returns: + Encoded feature tensor. + """ if waveform.ndim < 3: waveform = waveform.unsqueeze(1) # channel dim @@ -85,4 +98,12 @@ def forward(self, waveform): return z def hash_tensor(self, tensor): - return hash(tuple(tensor.reshape(-1).tolist())) \ No newline at end of file + """Return a hash of a tensor's flattened values. + + Args: + tensor: Tensor to hash. + + Returns: + Integer hash of the flattened tensor values. 
+ """ + return hash(tuple(tensor.reshape(-1).tolist())) diff --git a/src/maai/encoder_components.py b/src/maai/encoder_components.py index 5d0162a..6c8d5c3 100644 --- a/src/maai/encoder_components.py +++ b/src/maai/encoder_components.py @@ -500,6 +500,18 @@ def get_cnn_layer( dilation: List[int] = [1], activation: str = "GELU", ): + """Build a sequential convolutional block with normalization and activation. + + Args: + dim: Channel dimension for the convolution. + kernel: List of kernel sizes. + stride: List of stride values. + dilation: List of dilation values. + activation: Activation class name in torch.nn. + + Returns: + Sequential convolutional block. + """ layers = [] layers.append(Rearrange("b t d -> b d t")) for k, s, d in zip(kernel, stride, dilation): @@ -508,4 +520,4 @@ def get_cnn_layer( layers.append(LayerNorm(dim)) layers.append(getattr(torch.nn, activation)()) layers.append(Rearrange("b d t -> b t d")) - return nn.Sequential(*layers) \ No newline at end of file + return nn.Sequential(*layers) diff --git a/src/maai/input.py b/src/maai/input.py index 118dfe7..ea3ff3d 100644 --- a/src/maai/input.py +++ b/src/maai/input.py @@ -1,3 +1,5 @@ +"""Audio input helpers for microphones, WAV files, and TCP streams.""" + import socket import pyaudio import queue @@ -13,6 +15,14 @@ import locale def available_mic_devices(print_out=True): + """Return a dictionary of available microphone devices. + + Args: + print_out: Whether to print the device list to stdout. + + Returns: + Mapping of device index to device info. + """ p = pyaudio.PyAudio() device_info = {} @@ -53,6 +63,7 @@ def available_mic_devices(print_out=True): return device_info class Base: + """Base class for audio sources with subscription support.""" FRAME_SIZE = 160 SAMPLING_RATE = 16000 def __init__(self): @@ -61,25 +72,51 @@ def __init__(self): self._is_thread_started = False def subscribe(self): + """Create and register a subscriber queue. + + Returns: + Queue receiving audio frames. 
+ """ q = queue.Queue() with self._lock: self._subscriber_queues.append(q) return q def _put_to_all_queues(self, data): + """Fan out audio data to all subscriber queues. + + Args: + data: Audio frame data to broadcast. + """ # Put data into all subscriber queues and the default queue with self._lock: for q in self._subscriber_queues: q.put(data) def get_audio_data(self, q=None): + """Get the next audio frame from a subscribed queue. + + Args: + q: Subscriber queue to read from. + + Returns: + Audio frame data. + """ return q.get() def _get_queue_size(self): + """Return the total queued frame count across subscribers.""" with self._lock: return sum([len(q.queue) for q in self._subscriber_queues]) class Mic(Base): + """Microphone-backed audio source. + + Args: + audio_gain: Gain multiplier applied to audio frames. + mic_device_index: Optional device index to capture from. + device_name: Optional device name substring to select. + """ def __init__(self, audio_gain=1.0, mic_device_index=-1, device_name=None): @@ -113,6 +150,7 @@ def __init__(self, audio_gain=1.0, mic_device_index=-1, device_name=None): start=False) def _read_mic(self): + """Read frames from the microphone and broadcast to subscribers.""" self.stream.start_stream() @@ -123,11 +161,18 @@ def _read_mic(self): self._put_to_all_queues(d) def start(self): + """Start the microphone read thread.""" if not self._is_thread_started: threading.Thread(target=self._read_mic, daemon=True).start() self._is_thread_started = True class Wav(Base): + """WAV-file-backed audio source. + + Args: + wav_file_path: Path to the WAV file. + audio_gain: Gain multiplier applied to audio frames. 
+ """ def __init__(self, wav_file_path, audio_gain=1.0): super().__init__() self.wav_file_path = wav_file_path @@ -150,6 +195,7 @@ def __init__(self, wav_file_path, audio_gain=1.0): self.raw_wav_queue.put(d) def _read_wav(self): + """Stream WAV frames in real time and broadcast to subscribers.""" start_time = time.time() frame_duration = self.FRAME_SIZE / self.SAMPLING_RATE pygame.mixer.init(frequency=16000, size=-16, channels=1, buffer=512) @@ -171,11 +217,21 @@ def _read_wav(self): self._put_to_all_queues(data) def start(self): + """Start the WAV streaming thread.""" if not self._is_thread_started: threading.Thread(target=self._read_wav, daemon=True).start() self._is_thread_started = True class Tcp(Base): + """TCP audio receiver that provides frames to subscribers. + + Args: + ip: IP address to bind/connect. + port: TCP port to bind/connect. + audio_gain: Gain multiplier applied to audio frames. + recv_float32: Whether to read 4-byte float frames. + client_mode: Whether to connect as a client. 
+ """ def __init__(self, ip='127.0.0.1', port=8501, audio_gain=1.0,recv_float32=False, client_mode=False): super().__init__() self.ip = ip @@ -189,6 +245,7 @@ def __init__(self, ip='127.0.0.1', port=8501, audio_gain=1.0,recv_float32=False, self.client_mode = client_mode # クライアントモードオプション def _server(self): + """Start the TCP server for incoming audio streams.""" while True: if self.conn is not None: time.sleep(0.1) @@ -206,6 +263,7 @@ def _server(self): continue def _client(self): + """Start the TCP client for audio streams.""" while True: if self.conn is not None: time.sleep(0.1) @@ -222,6 +280,7 @@ def _client(self): continue def _process(self): + """Receive audio frames over TCP and broadcast to subscribers.""" import struct while True: try: @@ -277,11 +336,13 @@ def _process(self): continue def start(self): + """Start the audio processing thread.""" if not self._is_thread_started_process: threading.Thread(target=self._process, daemon=True).start() self._is_thread_started_process = True def start_server(self): + """Start the TCP server/client thread.""" if not self._is_thread_started_server: if self.client_mode: threading.Thread(target=self._client, daemon=True).start() @@ -290,15 +351,29 @@ def start_server(self): self._is_thread_started_server = True def _send_data_manual(self, data): + """Send raw audio bytes to the connected peer. + + Args: + data: Raw byte payload to send. + """ if self.conn is None: raise ConnectionError("No connection established. Call start_server() first.") self.conn.send(data) def is_connected(self): + """Return True when a TCP connection is established.""" return self.conn is not None and self.addr is not None class TcpMic(Base): + """Microphone source that sends frames to a TCP server. + + Args: + server_ip: Target server IP. + port: Target server port. + audio_gain: Gain multiplier applied to audio frames. + mic_device_index: Device index to capture from. 
+ """ def __init__(self, server_ip='127.0.0.1', port=8501, audio_gain=1.0, mic_device_index=0): self.ip = server_ip self.port = port @@ -313,11 +388,13 @@ def __init__(self, server_ip='127.0.0.1', port=8501, audio_gain=1.0, mic_device_ input_device_index=self.mic_device_index) def connect_server(self): + """Connect to the TCP server.""" self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.connect((self.ip, self.port)) print('[CLIENT] Connected to the server') def _start_client(self): + """Continuously send microphone frames to the server.""" while True: try: self.connect_server() @@ -349,10 +426,17 @@ def _start_client(self): time.sleep(0.5) def start(self): + """Start the TCP microphone client thread.""" threading.Thread(target=self._start_client, daemon=True).start() class TcpChunk(Base): + """TCP receiver that processes raw audio chunks. + + Args: + server_ip: Target server IP. + port: Target server port. + """ def __init__(self, server_ip='127.0.0.1', port=8501): self.ip = server_ip self.port = port @@ -360,11 +444,13 @@ def __init__(self, server_ip='127.0.0.1', port=8501): self.sock = None def connect_server(self): + """Connect to the TCP server.""" self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.connect((self.ip, self.port)) print('[CLIENT] Connected to the server') def _start_client(self): + """Receive raw chunks and process them.""" while True: try: self.connect_server() @@ -392,9 +478,15 @@ def _start_client(self): time.sleep(0.5) def start(self): + """Start the TCP chunk receiver thread.""" threading.Thread(target=self._start_client, daemon=True).start() def put_chunk(self, chunk_data): + """Send a chunk of audio data to the connected server. + + Args: + chunk_data: Audio frame array to send. 
+ """ if self.sock is not None: data_sent = util.conv_floatarray_2_byte(chunk_data) self.sock.sendall(data_sent) diff --git a/src/maai/model.py b/src/maai/model.py index 8271474..956a771 100644 --- a/src/maai/model.py +++ b/src/maai/model.py @@ -1,3 +1,5 @@ +"""Model orchestration for streaming VAP inference.""" + import torch import torch.nn as nn import time @@ -17,7 +19,7 @@ # from .models.vap_prompt import VapGPT_prompt class Maai(): - + """Run streaming VAP inference from paired audio sources.""" BINS_P_NOW = [0, 1] BINS_PFUTURE = [2, 3] @@ -39,6 +41,22 @@ def __init__( use_kv_cache: bool = True, local_model = None, ): + """Initialize the Maai model with audio sources and configuration. + + Args: + mode: Model mode (e.g., "vap", "vap_mc", "bc", "bc_2type", "nod"). + lang: Language code for model selection. + audio_ch1: Audio source for channel 1. + audio_ch2: Audio source for channel 2. + frame_rate: Output frame rate in Hz. + context_len_sec: Context window length in seconds. + device: Torch device to run inference on. + cpc_model: Path to CPC checkpoint. + cache_dir: Optional cache directory for model downloads. + force_download: Whether to force model downloads. + use_kv_cache: Whether to use KV cache for streaming. + local_model: Optional local model checkpoint path. + """ conf = VapConfig() @@ -160,6 +178,7 @@ def __init__( self._worker_thread = None def worker(self): + """Background worker loop that reads audio and runs inference.""" # Clear the queues at the start # This is to ensure that the queues are empty before starting the processing loop @@ -189,7 +208,7 @@ def worker(self): # self._mic2_queue.queue.clear() def start(self): - + """Start audio capture and processing threads.""" self.mic1.start() self.mic2.start() self._stop_event.clear() @@ -225,7 +244,12 @@ def stop(self, wait: bool = True, timeout: float = 2.0): pass def process(self, x1, x2): - + """Process one frame of audio from both channels. + + Args: + x1: Audio frame for channel 1. 
+ x2: Audio frame for channel 2. + """ time_start = time.time() # Initialize buffer if empty @@ -374,6 +398,11 @@ def process(self, x1, x2): self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy() def get_result(self): + """Return the next available inference result. + + Returns: + Next result dictionary taken from the result queue. + """ return self.result_dict_queue.get() def set_prompt_ch1(self, prompt: str): diff --git a/src/maai/models/config.py b/src/maai/models/config.py index 7f49339..7463323 100644 --- a/src/maai/models/config.py +++ b/src/maai/models/config.py @@ -1,3 +1,5 @@ +"""Configuration dataclass for VAP models.""" + from dataclasses import dataclass, field from typing import List @@ -5,6 +7,7 @@ @dataclass class VapConfig: + """Configuration parameters for VAP model components.""" sample_rate: int = 16000 frame_hz: int = 10 bin_times: List[float] = field(default_factory=lambda: BIN_TIMES) @@ -40,6 +43,15 @@ class VapConfig: @staticmethod def add_argparse_args(parser, fields_added=[]): + """Add configuration fields to an argparse parser. + + Args: + parser: argparse parser instance. + fields_added: List to append added field names to. + + Returns: + Tuple of (parser, fields_added). + """ for k, v in VapConfig.__dataclass_fields__.items(): if k == "bin_times": parser.add_argument( @@ -52,10 +64,18 @@ def add_argparse_args(parser, fields_added=[]): @staticmethod def args_to_conf(args): + """Convert argparse args into a VapConfig instance. + + Args: + args: argparse namespace with "vap_" prefixed fields. + + Returns: + VapConfig instance. 
+ """ return VapConfig( **{ k.replace("vap_", ""): v for k, v in vars(args).items() if k.startswith("vap_") } - ) \ No newline at end of file + ) diff --git a/src/maai/models/vap.py b/src/maai/models/vap.py index 6b0e43e..a9981b6 100644 --- a/src/maai/models/vap.py +++ b/src/maai/models/vap.py @@ -1,3 +1,5 @@ +"""VAP model definition for turn-taking prediction.""" + import torch import torch.nn as nn from torch import Tensor @@ -10,11 +12,16 @@ from ..objective import ObjectiveVAP class VapGPT(nn.Module): - + """GPT-based VAP model for stereo turn-taking prediction.""" BINS_P_NOW = [0, 1] BINS_PFUTURE = [2, 3] def __init__(self, conf: Optional[VapConfig] = None): + """Initialize the VAP model and its submodules. + + Args: + conf: Optional configuration object. + """ super().__init__() if conf is None: conf = VapConfig() @@ -62,6 +69,11 @@ def __init__(self, conf: Optional[VapConfig] = None): self.vap_head = nn.Linear(conf.dim, self.objective.n_classes) def load_encoder(self, cpc_model): + """Load CPC encoders and optionally freeze them. + + Args: + cpc_model: Path to the CPC checkpoint. + """ # Audio Encoder #if self.conf.encoder_type == "cpc": @@ -90,16 +102,34 @@ def load_encoder(self, cpc_model): @property def horizon_time(self): + """Return the prediction horizon time in seconds.""" return self.objective.horizon_time def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tensor, Tensor]: - + """Encode paired audio tensors into embeddings. + + Args: + audio1: Audio tensor for channel 1. + audio2: Audio tensor for channel 2. + + Returns: + Tuple of encoded tensors (x1, x2). + """ x1 = self.encoder1(audio1) # speaker 1 x2 = self.encoder2(audio2) # speaker 2 return x1, x2 def vad_loss(self, vad_output, vad): + """Compute VAD loss for the model outputs. + + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. 
+ """ return F.binary_cross_entropy_with_logits(vad_output, vad) def forward( @@ -171,4 +201,4 @@ def forward( ret = {"p_now": p_now, "p_future": p_future, "vad": [vad1, vad2]} - return ret, new_cache \ No newline at end of file + return ret, new_cache diff --git a/src/maai/models/vap_bc.py b/src/maai/models/vap_bc.py index d05d719..44c5349 100644 --- a/src/maai/models/vap_bc.py +++ b/src/maai/models/vap_bc.py @@ -1,3 +1,5 @@ +"""VAP model variant for backchannel prediction.""" + import torch import torch.nn as nn from torch import Tensor @@ -10,8 +12,13 @@ from ..objective import ObjectiveVAP class VapGPT_bc(nn.Module): - + """GPT-based VAP model for backchannel detection.""" def __init__(self, conf: Optional[VapConfig] = None): + """Initialize the backchannel VAP model. + + Args: + conf: Optional configuration object. + """ super().__init__() if conf is None: conf = VapConfig() @@ -55,6 +62,11 @@ def __init__(self, conf: Optional[VapConfig] = None): self.bc_head = nn.Linear(conf.dim, 1) def load_encoder(self, cpc_model): + """Load CPC encoders and optionally freeze them. + + Args: + cpc_model: Path to the CPC checkpoint. + """ # Audio Encoder self.encoder1 = EncoderCPC( @@ -79,10 +91,19 @@ def load_encoder(self, cpc_model): @property def horizon_time(self): + """Return the prediction horizon time in seconds.""" return self.objective.horizon_time def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tensor, Tensor]: - + """Encode paired audio tensors into embeddings. + + Args: + audio1: Audio tensor for channel 1. + audio2: Audio tensor for channel 2. + + Returns: + Tuple of encoded tensors (x1, x2). + """ # Channel swap for temporal consistency x1 = self.encoder1(audio2) # speaker 1 (User) x2 = self.encoder2(audio1) # speaker 2 (System) @@ -90,6 +111,15 @@ def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tens return x1, x2 def vad_loss(self, vad_output, vad): + """Compute VAD loss for the model outputs. 
+ + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. + """ return F.binary_cross_entropy_with_logits(vad_output, vad) def forward( @@ -136,4 +166,4 @@ def forward( ret = {"p_bc": p_bc} - return ret, new_cache \ No newline at end of file + return ret, new_cache diff --git a/src/maai/models/vap_bc_2type.py b/src/maai/models/vap_bc_2type.py index 38a6f00..1fc74b3 100644 --- a/src/maai/models/vap_bc_2type.py +++ b/src/maai/models/vap_bc_2type.py @@ -1,3 +1,5 @@ +"""VAP model variant for two-type backchannel prediction.""" + import torch import torch.nn as nn from torch import Tensor @@ -10,8 +12,13 @@ from ..objective import ObjectiveVAP class VapGPT_bc_2type(nn.Module): - + """GPT-based VAP model for reactive and emotional backchannels.""" def __init__(self, conf: Optional[VapConfig] = None): + """Initialize the two-type backchannel VAP model. + + Args: + conf: Optional configuration object. + """ super().__init__() if conf is None: conf = VapConfig() @@ -62,6 +69,11 @@ def __init__(self, conf: Optional[VapConfig] = None): self.bc_head = nn.Linear(conf.dim, 3) def load_encoder(self, cpc_model): + """Load CPC encoders and optionally freeze them. + + Args: + cpc_model: Path to the CPC checkpoint. + """ # Audio Encoder self.encoder1 = EncoderCPC( @@ -86,10 +98,19 @@ def load_encoder(self, cpc_model): @property def horizon_time(self): + """Return the prediction horizon time in seconds.""" return self.objective.horizon_time def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tensor, Tensor]: - + """Encode paired audio tensors into embeddings. + + Args: + audio1: Audio tensor for channel 1. + audio2: Audio tensor for channel 2. + + Returns: + Tuple of encoded tensors (x1, x2). 
+ """ # Channel swap for temporal consistency x1 = self.encoder1(audio2) # speaker 1 (User) x2 = self.encoder2(audio1) # speaker 2 (System) @@ -97,6 +118,15 @@ def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tens return x1, x2 def vad_loss(self, vad_output, vad): + """Compute VAD loss for the model outputs. + + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. + """ return F.binary_cross_entropy_with_logits(vad_output, vad) def forward( @@ -143,4 +173,4 @@ def forward( ret = {"p_bc_react": p_bc_react, "p_bc_emo": p_bc_emo} - return ret, new_cache \ No newline at end of file + return ret, new_cache diff --git a/src/maai/models/vap_nod.py b/src/maai/models/vap_nod.py index 733f7e5..3fefd99 100644 --- a/src/maai/models/vap_nod.py +++ b/src/maai/models/vap_nod.py @@ -1,3 +1,5 @@ +"""VAP model variant for nodding prediction.""" + import torch import torch.nn as nn from torch import Tensor @@ -10,7 +12,13 @@ from ..objective import ObjectiveVAP class VapGPT_nod(nn.Module): + """GPT-based VAP model for nodding detection.""" def __init__(self, conf: Optional[VapConfig] = None): + """Initialize the nodding VAP model. + + Args: + conf: Optional configuration object. + """ super().__init__() if conf is None: conf = VapConfig() @@ -62,6 +70,11 @@ def __init__(self, conf: Optional[VapConfig] = None): self.bc_head = nn.Linear(conf.dim, 1) def load_encoder(self, cpc_model): + """Load CPC encoders and optionally freeze them. + + Args: + cpc_model: Path to the CPC checkpoint. + """ # Audio Encoder self.encoder1 = EncoderCPC( @@ -86,10 +99,19 @@ def load_encoder(self, cpc_model): @property def horizon_time(self): + """Return the prediction horizon time in seconds.""" return self.objective.horizon_time def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tensor, Tensor]: - + """Encode paired audio tensors into embeddings. + + Args: + audio1: Audio tensor for channel 1. 
+ audio2: Audio tensor for channel 2. + + Returns: + Tuple of encoded tensors (x1, x2). + """ # Channel swap for temporal consistency x1 = self.encoder1(audio2) # speaker 1 (User) x2 = self.encoder2(audio1) # speaker 2 (System) @@ -97,6 +119,15 @@ def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tens return x1, x2 def vad_loss(self, vad_output, vad): + """Compute VAD loss for the model outputs. + + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. + """ return F.binary_cross_entropy_with_logits(vad_output, vad) def forward( @@ -152,4 +183,4 @@ def forward( "p_nod_long_p": p_nod_long_p, } - return ret, new_cache \ No newline at end of file + return ret, new_cache diff --git a/src/maai/models/vap_prompt.py b/src/maai/models/vap_prompt.py index 3cff0d7..c7c6a40 100644 --- a/src/maai/models/vap_prompt.py +++ b/src/maai/models/vap_prompt.py @@ -1,3 +1,5 @@ +"""VAP model variant with prompt-conditioned behavior.""" + import torch import torch.nn as nn from torch import Tensor @@ -12,7 +14,7 @@ from sentence_transformers import SentenceTransformer class VapGPT_prompt(nn.Module): - + """GPT-based VAP model augmented with text prompt embeddings.""" BINS_P_NOW = [0, 1] BINS_PFUTURE = [2, 3] @@ -20,7 +22,11 @@ class VapGPT_prompt(nn.Module): prompt_model_name = "cl-nagoya/ruri-v3-pt-30m" def __init__(self, conf: Optional[VapConfig] = None): - + """Initialize the prompt-conditioned VAP model. + + Args: + conf: Optional configuration object. + """ super().__init__() # print this model is a beta version @@ -83,6 +89,11 @@ def __init__(self, conf: Optional[VapConfig] = None): self.set_prompt_ch2("発話前に少し間を取り、考えてから丁寧に話し始めてください。応答は急がず、落ち着いたテンポを意識してください。") def load_encoder(self, cpc_model): + """Load CPC encoders and optionally freeze them. + + Args: + cpc_model: Path to the CPC checkpoint. 
+ """ # Audio Encoder #if self.conf.encoder_type == "cpc": @@ -111,19 +122,43 @@ def load_encoder(self, cpc_model): @property def horizon_time(self): + """Return the prediction horizon time in seconds.""" return self.objective.horizon_time def encode_audio(self, audio1: torch.Tensor, audio2: torch.Tensor) -> Tuple[Tensor, Tensor]: - + """Encode paired audio tensors into embeddings. + + Args: + audio1: Audio tensor for channel 1. + audio2: Audio tensor for channel 2. + + Returns: + Tuple of encoded tensors (x1, x2). + """ x1 = self.encoder1(audio1) # speaker 1 x2 = self.encoder2(audio2) # speaker 2 return x1, x2 def vad_loss(self, vad_output, vad): + """Compute VAD loss for the model outputs. + + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. + """ return F.binary_cross_entropy_with_logits(vad_output, vad) def set_prompt_ch1(self, prompt: str, device: torch.device = torch.device('cpu')): + """Set the prompt embedding for channel 1. + + Args: + prompt: Prompt text for channel 1. + device: Torch device to move embeddings to. + """ embedding_ch1_ = self.prompt_embedding_model.encode([prompt], normalize_embeddings=True)[0] self.embedding_ch1 = torch.tensor(embedding_ch1_).unsqueeze(0) @@ -134,6 +169,12 @@ def set_prompt_ch1(self, prompt: str, device: torch.device = torch.device('cpu') # input("Press Enter to continue...") def set_prompt_ch2(self, prompt: str, device: torch.device = torch.device('cpu')): + """Set the prompt embedding for channel 2. + + Args: + prompt: Prompt text for channel 2. + device: Torch device to move embeddings to. 
+ """ embedding_ch2_ = self.prompt_embedding_model.encode([prompt], normalize_embeddings=True)[0] self.embedding_ch2 = torch.tensor(embedding_ch2_).unsqueeze(0) diff --git a/src/maai/modules.py b/src/maai/modules.py index acf6ee6..d386d19 100644 --- a/src/maai/modules.py +++ b/src/maai/modules.py @@ -1,3 +1,5 @@ +"""Transformer building blocks used by the VAP models.""" + import math import torch import torch.nn as nn @@ -13,6 +15,18 @@ def ffn_block( dropout: float = 0.0, bias: bool = False, ) -> nn.Sequential: + """Create a feed-forward network block used in transformers. + + Args: + din: Input feature dimension. + dff: Hidden feed-forward dimension. + activation: Activation class name in torch.nn. + dropout: Dropout probability. + bias: Whether to use bias in linear layers. + + Returns: + Sequential feed-forward block. + """ return nn.Sequential( nn.Linear(din, dff, bias=bias), getattr(nn, activation)(), @@ -121,6 +135,7 @@ def forward( class MultiHeadAttentionAlibi(MultiHeadAttention): + """Multi-head attention with ALiBi positional bias.""" def __init__(self, dim: int, num_heads: int, dropout: float, bias: bool = False, context_limit: int = -1): super().__init__(dim, num_heads, dropout, bias) # self.m = torch.tensor(MultiHeadAttentionAlibi.get_slopes(num_heads)) @@ -306,6 +321,7 @@ def forward( class TransformerStereoLayer(TransformerLayer): + """Transformer layer that processes two streams with cross-attention.""" def forward( self, x1: torch.Tensor, @@ -420,6 +436,7 @@ def forward( class GPTStereo(GPT): + """Stereo GPT that performs cross-attention between two streams.""" def _build_layers(self): layers = [] for _ in range(self.num_layers): @@ -568,6 +585,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: def test_gpt(): + """Quick sanity check for GPT attention outputs.""" model = GPT(dim=256, dff_k=3, num_layers=4, num_heads=8) x = torch.rand((4, 20, model.dim)) with torch.no_grad(): diff --git a/src/maai/objective.py 
b/src/maai/objective.py index 5a4c264..d542f7e 100644 --- a/src/maai/objective.py +++ b/src/maai/objective.py @@ -1,3 +1,5 @@ +"""Objective functions and label utilities for VAP training.""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,10 +10,26 @@ def bin_times_to_frames(bin_times: List[float], frame_hz: int) -> List[int]: + """Convert bin durations (seconds) into frame counts. + + Args: + bin_times: List of bin durations in seconds. + frame_hz: Frame rate in Hz. + + Returns: + List of bin lengths in frames. + """ return (torch.tensor(bin_times) * frame_hz).long().tolist() class ProjectionWindow: + """Extract projected activity windows for future bins. + + Args: + bin_times: List of bin durations in seconds. + frame_hz: Frame rate in Hz. + threshold_ratio: Threshold for bin activation. + """ def __init__( self, bin_times: List = [0.2, 0.4, 0.6, 0.8], @@ -29,6 +47,7 @@ def __init__( self.horizon = sum(self.bin_frames) def __repr__(self) -> str: + """Return a readable representation of the projection window.""" s = f"{self.__class__.__name__}(\n" s += f" bin_times: {self.bin_times}\n" s += f" bin_frames: {self.bin_frames}\n" @@ -72,11 +91,20 @@ def projection_bins(self, projection_window: Tensor) -> Tensor: return torch.stack(v_bins, dim=-1) # (*, t, c, n_bins) def __call__(self, va: Tensor) -> Tensor: + """Project voice activity and return bin indicators. + + Args: + va: Voice activity tensor of shape (B, N, C). + + Returns: + Bin indicator tensor. + """ projection_windows = self.projection(va) return self.projection_bins(projection_windows) class Codebook(nn.Module): + """Codebook for quantizing projection windows to discrete classes.""" def __init__(self, bin_frames): super().__init__() self.bin_frames = bin_frames @@ -91,6 +119,15 @@ def __init__(self, bin_frames): self.emb.weight.requires_grad_(False) def single_idx_to_onehot(self, idx: int, d: int = 8) -> Tensor: + """Convert an integer index to a binary one-hot vector. 
+ + Args: + idx: Integer index to encode. + d: Number of bits in the output vector. + + Returns: + Length-d tensor encoding idx as binary digits. + """ assert idx < 2 ** d, "must be possible with {d} binary digits" z = torch.zeros(d) b = bin(idx).replace("0b", "") @@ -139,14 +176,37 @@ def encode(self, x: Tensor) -> Tensor: return embed_ind def decode(self, idx: Tensor): + """Decode codebook indices back into bin activations. + + Args: + idx: Codebook indices. + + Returns: + Decoded bin activations. + """ v = self.emb(idx) return rearrange(v, "... (c b) -> ... c b", c=2) def forward(self, projection_windows: Tensor): + """Encode projection windows into codebook indices. + + Args: + projection_windows: Projection window tensor. + + Returns: + Codebook indices. + """ return self.encode(projection_windows) class ObjectiveVAP(nn.Module): + """Loss utilities and label extraction for VAP training. + + Args: + bin_times: List of bin durations in seconds. + frame_hz: Frame rate in Hz. + threshold_ratio: Threshold for bin activation. + """ def __init__( self, bin_times: List[float] = [0.2, 0.4, 0.6, 0.8], @@ -169,6 +229,7 @@ def __init__( self.lid_n_classes = 3 def __repr__(self): + """Return a readable representation of the objective configuration.""" s = str(self.__class__.__name__) s += f"\n{self.codebook}" s += f"\n{self.projection_window_extractor}" @@ -190,6 +251,17 @@ def probs_next_speaker_aggregate( to_bin: int = 3, scale_with_bins: bool = False, ) -> Tensor: + """Aggregate probabilities over bins for next-speaker metrics. + + Args: + probs: Probability tensor of shape (B, N, C). + from_bin: Starting bin index. + to_bin: Ending bin index. + scale_with_bins: Whether to scale by bin sizes. + + Returns: + Aggregated probabilities. 
+ """ assert ( probs.ndim == 3 ), f"Expected probs of shape (B, n_frames, n_classes) but got {probs.shape}" @@ -206,14 +278,38 @@ def probs_next_speaker_aggregate( return p_all def window_to_win_dialog_states(self, wins): + """Compute dialog state counts from projection windows. + + Args: + wins: Projection window tensor. + + Returns: + Dialog state counts. + """ return (wins.sum(-1) > 0).sum(-1) def get_labels(self, va: Tensor) -> Tensor: + """Return discrete labels for voice activity inputs. + + Args: + va: Voice activity tensor. + + Returns: + Label indices tensor. + """ projection_windows = self.projection_window_extractor(va).type(va.dtype) idx = self.codebook(projection_windows) return idx def get_labels_bc(self, bc_frame: Tensor) -> Tensor: + """Return backchannel labels aligned to projection windows. + + Args: + bc_frame: Backchannel frame tensor. + + Returns: + Backchannel label tensor. + """ # # bc_frame: (B, N_FRAMES) @@ -236,6 +332,14 @@ def get_labels_bc(self, bc_frame: Tensor) -> Tensor: return bc_projection_frame def get_da_labels(self, va: Tensor) -> Tuple[Tensor, Tensor]: + """Return labels and dialog state counts for a sequence. + + Args: + va: Voice activity tensor. + + Returns: + Tuple of (labels, dialog state counts). + """ projection_windows = self.projection_window_extractor(va).type(va.dtype) idx = self.codebook(projection_windows) ds = self.window_to_win_dialog_states(projection_windows) @@ -244,6 +348,16 @@ def get_da_labels(self, va: Tensor) -> Tuple[Tensor, Tensor]: def loss_vap( self, logits: Tensor, labels: Tensor, reduction: str = "mean" ) -> Tensor: + """Compute the VAP classification loss. + + Args: + logits: Logits tensor of shape (B, N, C). + labels: Label tensor of shape (B, N). + reduction: Reduction mode for loss. + + Returns: + Loss tensor. 
+ """ assert ( logits.ndim == 3 ), f"Exptected logits of shape (B, N_FRAMES, N_CLASSES) but got {logits.shape}" @@ -269,6 +383,16 @@ def loss_vap( def loss_lid( self, logits: Tensor, labels: Tensor, reduction: str = "mean" ) -> Tensor: + """Compute the language ID loss. + + Args: + logits: Logits tensor of shape (B, N, C). + labels: Label tensor of shape (B, N). + reduction: Reduction mode for loss. + + Returns: + Loss tensor. + """ assert ( logits.ndim == 3 ), f"Exptected logits of shape (B, N_FRAMES, N_CLASSES) but got {logits.shape}" @@ -293,13 +417,41 @@ def loss_lid( return loss def loss_bc(self, bc_output, bc_label, bc_positive_weight=1.0): + """Compute the backchannel loss. + + Args: + bc_output: Backchannel logits. + bc_label: Backchannel labels. + bc_positive_weight: Positive class weight. + + Returns: + Loss tensor. + """ return F.binary_cross_entropy_with_logits(bc_output, bc_label, pos_weight=torch.tensor([bc_positive_weight], device=bc_output.device)) def loss_vad(self, vad_output, vad): + """Compute the VAD loss for stereo inputs. + + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. + """ n = vad_output.shape[-2] return F.binary_cross_entropy_with_logits(vad_output, vad[:, :n]) def loss_vad_mono(self, vad_output, vad): + """Compute the VAD loss for mono inputs. + + Args: + vad_output: VAD logits. + vad: VAD labels. + + Returns: + Loss tensor. + """ n = vad_output.shape[-2] v = vad[:, :n, 1] # print(torch.squeeze(vad_output)) @@ -350,6 +502,17 @@ def extract_prediction_and_targets( events: Dict[str, List[List[Tuple[int, int, int]]]], device=None, ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: + """Extract prediction and target tensors for evaluation metrics. + + Args: + p_now: Short-term probabilities. + p_fut: Long-term probabilities. + events: Event indices for evaluation. + device: Optional device for outputs. + + Returns: + Tuple of (predictions, targets). 
+ """ batch_size = len(events["hold"]) preds = {"hs": [], "hs2": [], "pred_shift": [], "pred_shift2": [], "ls": [], "pred_backchannel": [], "pred_backchannel2": [], "lid": []} diff --git a/src/maai/output.py b/src/maai/output.py index 12ecb9a..453419a 100644 --- a/src/maai/output.py +++ b/src/maai/output.py @@ -1,3 +1,5 @@ +"""Output utilities for rendering and transmitting model results.""" + import sys import math import time @@ -13,13 +15,29 @@ import matplotlib.colors as mcolors def _draw_bar(value: float, length: int = 30) -> str: - """基本的なバーグラフを描画""" + """Draw a basic bar graph representation. + + Args: + value: Normalized value between 0 and 1. + length: Total bar length in characters. + + Returns: + Rendered bar string. + """ bar_len = min(length, max(0, int(value * length))) return '█' * bar_len + '-' * (length - bar_len) def _draw_symmetric_bar(value: float, length: int = 30) -> str: - """対称的なバーグラフを描画(-1.0から1.0の範囲)""" + """Draw a symmetric bar graph in the range -1.0 to 1.0. + + Args: + value: Value in the range -1.0 to 1.0. + length: Total bar length in characters. + + Returns: + Rendered bar string. + """ max_len = length // 2 value = max(-1.0, min(1.0, value)) if value >= 0: @@ -31,7 +49,15 @@ def _draw_symmetric_bar(value: float, length: int = 30) -> str: def _draw_balance_bar(value: float, length: int = 30) -> str: - """0.5を中心としたバランスバーを描画""" + """Draw a balance bar centered at 0.5. + + Args: + value: Value in the range 0.0 to 1.0. + length: Total bar length in characters. + + Returns: + Rendered bar string. + """ # 2チャネルの場合は1チャネル目のデータを渡すこと max_len = length // 2 value = max(0.0, min(1.0, value)) @@ -47,14 +73,29 @@ def _draw_balance_bar(value: float, length: int = 30) -> str: def _rms(values: Union[List[float], tuple]) -> float: - """RMS値を計算""" + """Compute the RMS value of a sequence. + + Args: + values: Sequence of numeric values. + + Returns: + RMS value. 
+ """ if not values: return 0.0 return math.sqrt(sum(x * x for x in values) / len(values)) def _format_value(value: Any, max_length: int = 50) -> str: - """値を適切な形式でフォーマット""" + """Format a value for display in a compact form. + + Args: + value: Value to format. + max_length: Maximum string length before truncation. + + Returns: + Formatted string representation. + """ if isinstance(value, (list, tuple)): if len(value) > 0: if isinstance(value[0], (int, float)): @@ -80,7 +121,17 @@ def _format_value(value: Any, max_length: int = 50) -> str: def _get_bar_for_value(key: str, value: Any, bar_length: int = 30, bar_type: str = "normal") -> tuple[str, float]: - """キーと値に基づいて適切なバーを選択""" + """Select the appropriate bar visualization for a value. + + Args: + key: Result key name. + value: Value to visualize. + bar_length: Total bar length in characters. + bar_type: Visualization style ("normal" or "balance"). + + Returns: + Tuple of (bar string, numeric value used). + """ if isinstance(value, (list, tuple)): if len(value) > 2: if isinstance(value[0], (int, float)): @@ -102,7 +153,7 @@ def _get_bar_for_value(key: str, value: Any, bar_length: int = 30, bar_type: str class ConsoleBar: """ - maai.get_result()の内容をバーグラフで可視化するクラス + Render maai.get_result() output as a console bar chart. """ def __init__(self, bar_length: int = 30, bar_type: str = "normal"): self.bar_length = bar_length @@ -110,6 +161,11 @@ def __init__(self, bar_length: int = 30, bar_type: str = "normal"): self._first = True def update(self, result: Dict[str, Any]): + """Render an updated bar chart for a result dictionary. + + Args: + result: Result dictionary from the model. 
+ """ if self._first: sys.stdout.write("\x1b[2J") # 初期クリア self._first = False @@ -155,6 +211,7 @@ def update(self, result: Dict[str, Any]): print("-" * (self.bar_length + 30)) class TcpReceiver: + """Receive VAP results over TCP and expose a queue interface.""" def __init__(self, ip, port, mode): self.ip = ip self.port = port @@ -163,6 +220,14 @@ def __init__(self, ip, port, mode): self.result_queue = queue.Queue() def _bytearray_2_vapresult(self, data: bytes) -> Dict[str, Any]: + """Decode a byte payload into a VAP result dict. + + Args: + data: Serialized result payload. + + Returns: + Decoded result dictionary. + """ if self.mode in ['vap', 'vap_mc', 'vap_prompt']: vap_result = util.conv_bytearray_2_vapresult(data) elif self.mode == 'bc_2type': @@ -174,11 +239,13 @@ def _bytearray_2_vapresult(self, data: bytes) -> Dict[str, Any]: return vap_result def connect_server(self): + """Connect to the TCP result server.""" self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.connect((self.ip, self.port)) print('[CLIENT] Connected to the server') def _start_client(self): + """Continuously receive results and enqueue them.""" while True: try: self.connect_server() @@ -209,12 +276,15 @@ def _start_client(self): time.sleep(0.5) def start(self): + """Start the TCP receiver thread.""" threading.Thread(target=self._start_client, daemon=True).start() def get_result(self): + """Return the next received result.""" return self.result_queue.get() class TcpTransmitter: + """Transmit VAP results over TCP to a client.""" def __init__(self, ip, port, mode): self.ip = ip self.port = port @@ -222,6 +292,14 @@ def __init__(self, ip, port, mode): self.result_queue = queue.Queue() def _vapresult_2_bytearray(self, result_dict: Dict[str, Any]) -> bytes: + """Encode a result dictionary to bytes for transmission. + + Args: + result_dict: Result dictionary to encode. + + Returns: + Serialized byte buffer. 
+ """ if self.mode in ['vap', 'vap_mc']: data_sent = util.conv_vapresult_2_bytearray(result_dict) elif self.mode == 'bc_2type': @@ -233,6 +311,7 @@ def _vapresult_2_bytearray(self, result_dict: Dict[str, Any]) -> bytes: return data_sent def _start_server(self): + """Start the TCP server loop for sending results.""" while True: try: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -256,14 +335,20 @@ def _start_server(self): continue def start_server(self): + """Launch the TCP server thread.""" threading.Thread(target=self._start_server, daemon=True).start() def update(self, result: Dict[str, Any]): + """Queue a result for transmission. + + Args: + result: Result dictionary to send. + """ self.result_queue.put(result) # 新規追加: GUIでバーグラフを表示するクラス class GuiBar: - """matplotlibを用いて結果をバーグラフでGUI表示するクラス""" + """Show results as a GUI bar chart using matplotlib.""" def __init__(self, bar_type: str = "normal"): self.bar_type = bar_type self.plt = plt @@ -277,7 +362,11 @@ def __init__(self, bar_type: str = "normal"): sns.set_theme(style="whitegrid") def update(self, result: Dict[str, Any]): - """resultのキーと値をバーグラフで更新表示する""" + """Update the bar chart using a result dictionary. + + Args: + result: Result dictionary from the model. + """ labels = [] values = [] for key, value in result.items(): @@ -319,6 +408,15 @@ def update(self, result: Dict[str, Any]): self.plt.pause(0.001) class GuiPlot: + """Plot waveform and probability traces in a GUI. + + Args: + shown_context_sec: Seconds of history to display. + frame_rate: Frame rate used for model outputs. + sample_rate: Sample rate of waveform inputs. + figsize: Matplotlib figure size. + use_fixed_draw_rate: Whether to throttle redraws. 
+ """ def __init__(self, shown_context_sec: int = 10, frame_rate: int = 10, sample_rate: int = 16000, figsize=(14, 10), use_fixed_draw_rate: bool = True): self.figsize = figsize self.shown_context_sec = shown_context_sec @@ -338,6 +436,11 @@ def __init__(self, shown_context_sec: int = 10, frame_rate: int = 10, sample_rat self._last_draw_time = 0.0 def _init_fig(self, result: Dict[str, any]): + """Initialize the matplotlib figure layout for result keys. + + Args: + result: Initial result dictionary. + """ special_keys = ['x1', 'x2', 'p_now', 'p_future', 'vad'] self.keys = [k for k in special_keys if k in result] + [k for k in result.keys() if k not in special_keys and k != 't'] n = len(self.keys) @@ -449,6 +552,11 @@ def _init_fig(self, result: Dict[str, any]): self.initialized = True def update(self, result: Dict[str, any]): + """Update all plots with the latest result values. + + Args: + result: Result dictionary from the model. + """ import time draw = True if self.use_fixed_draw_rate: @@ -551,4 +659,4 @@ def update(self, result: Dict[str, any]): self.lines[key][0].set_height(v) if draw: self.fig.canvas.draw_idle() - self.fig.canvas.flush_events() \ No newline at end of file + self.fig.canvas.flush_events() diff --git a/src/maai/util.py b/src/maai/util.py index acd7e6e..84f3a49 100644 --- a/src/maai/util.py +++ b/src/maai/util.py @@ -1,3 +1,5 @@ +"""Utility helpers for model loading and byte conversions.""" + import torch from huggingface_hub import hf_hub_download, list_repo_files @@ -32,7 +34,20 @@ } def load_vap_model(mode: str, frame_rate: int, context_len_sec: float, language: str = "jp", device: str = "cpu", cache_dir: str = None, force_download: bool = False): - + """Download and load a VAP model state dict from Hugging Face. + + Args: + mode: Model family (e.g., "vap", "vap_mc", "bc", "bc_2type", "nod"). + frame_rate: Frame rate in Hz used to select the checkpoint. + context_len_sec: Context length in seconds used to select the checkpoint. 
+ language: Language code for the model variant. + device: Torch device for loading the state dict. + cache_dir: Optional Hugging Face cache directory. + force_download: Whether to force downloading weights. + + Returns: + Loaded PyTorch state dict. + """ if mode == "vap": if language == "jp": repo_id = repo_ids["vap_jp"] @@ -180,6 +195,11 @@ def load_vap_model(mode: str, frame_rate: int, context_len_sec: float, language: return sd def get_available_models(): + """Return available model files per repository. + + Returns: + Mapping from repo id to a list of checkpoint filenames. + """ available_models = {} for repo_id in repo_ids.values(): files = list_repo_files(repo_id) @@ -195,7 +215,15 @@ def get_available_models(): # def conv_2int16_2_byte(val1, val2): - + """Convert two int16 values to a 4-byte little-endian buffer. + + Args: + val1: First int16 value. + val2: Second int16 value. + + Returns: + Combined byte buffer. + """ b1 = val1.to_bytes(2, BYTE_ORDER) b2 = val2.to_bytes(2, BYTE_ORDER) @@ -209,7 +237,15 @@ def conv_2int16_2_byte(val1, val2): return b def conv_2int16array_2_bytearray(arr1, arr2): - + """Convert two int16 arrays into a concatenated byte array. + + Args: + arr1: First int16 array. + arr2: Second int16 array. + + Returns: + Concatenated byte buffer. + """ if len(arr1) != len(arr2): raise ValueError('Two arrays must have the same length') @@ -225,7 +261,15 @@ def conv_2int16array_2_bytearray(arr1, arr2): # def conv_2float_2_byte(val1, val2): - + """Convert two float64 values into a little-endian byte buffer. + + Args: + val1: First float value. + val2: Second float value. + + Returns: + Combined byte buffer. + """ b1 = struct.pack(' Byte # def conv_vapresult_2_bytearray(vap_result): - + """Serialize a VAP result dictionary to bytes. + + Args: + vap_result: Result dictionary containing arrays and probabilities. + + Returns: + Serialized byte buffer. 
+ """ b = b'' #print(type(vap_result['t'])) b += struct.pack(' VAP result # def conv_bytearray_2_vapresult(barr): - + """Deserialize a VAP result dictionary from bytes. + + Args: + barr: Byte buffer containing a serialized VAP result. + + Returns: + Decoded VAP result dictionary. + """ idx = 0 t = struct.unpack(' Byte # def conv_vapresult_2_bytearray_bc_2type(vap_result): - + """Serialize a VAP backchannel 2-type result dictionary to bytes. + + Args: + vap_result: Result dictionary with backchannel outputs. + + Returns: + Serialized byte buffer. + """ b = b'' #print(type(vap_result['t'])) b += struct.pack(' VAP result # def conv_bytearray_2_vapresult_bc_2type(barr): - + """Deserialize a VAP backchannel 2-type result from bytes. + + Args: + barr: Byte buffer containing a serialized result. + + Returns: + Decoded result dictionary. + """ idx = 0 t = struct.unpack('