Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/maai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Public package exports for the Maai library."""

from maai.model import Maai
import maai.input as MaaiInput
import maai.output as MaaiOutput
from maai.util import get_available_models
from maai.util import get_available_models
23 changes: 22 additions & 1 deletion src/maai/encoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Audio encoder wrappers and utilities."""

import torch
import torch.nn as nn
import einops
Expand Down Expand Up @@ -48,19 +50,30 @@ def __init__(self, load_pretrained=True, freeze=True, cpc_model=''):
self.freeze()

def get_default_conf(self):
"""Return a placeholder default configuration."""
return {""}

def freeze(self):
"""Freeze encoder parameters for inference."""
for p in self.encoder.parameters():
p.requires_grad_(False)
print(f"Froze {self.__class__.__name__}!")

def unfreeze(self):
"""Unfreeze encoder parameters for training."""
for p in self.encoder.parameters():
p.requires_grad_(True)
print(f"Trainable {self.__class__.__name__}!")

def forward(self, waveform):
"""Encode waveform frames into downsampled representations.

Args:
waveform: Input waveform tensor.

Returns:
Encoded feature tensor.
"""

if waveform.ndim < 3:
waveform = waveform.unsqueeze(1) # channel dim
Expand All @@ -85,4 +98,12 @@ def forward(self, waveform):
return z

def hash_tensor(self, tensor):
return hash(tuple(tensor.reshape(-1).tolist()))
"""Return a hash of a tensor's flattened values.

Args:
tensor: Tensor to hash.

Returns:
Integer hash of the flattened tensor values.
"""
return hash(tuple(tensor.reshape(-1).tolist()))
14 changes: 13 additions & 1 deletion src/maai/encoder_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,18 @@ def get_cnn_layer(
dilation: List[int] = [1],
activation: str = "GELU",
):
"""Build a sequential convolutional block with normalization and activation.

Args:
dim: Channel dimension for the convolution.
kernel: List of kernel sizes.
stride: List of stride values.
dilation: List of dilation values.
activation: Activation class name in torch.nn.

Returns:
Sequential convolutional block.
"""
layers = []
layers.append(Rearrange("b t d -> b d t"))
for k, s, d in zip(kernel, stride, dilation):
Expand All @@ -508,4 +520,4 @@ def get_cnn_layer(
layers.append(LayerNorm(dim))
layers.append(getattr(torch.nn, activation)())
layers.append(Rearrange("b d t -> b t d"))
return nn.Sequential(*layers)
return nn.Sequential(*layers)
92 changes: 92 additions & 0 deletions src/maai/input.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Audio input helpers for microphones, WAV files, and TCP streams."""

import socket
import pyaudio
import queue
Expand All @@ -13,6 +15,14 @@
import locale

def available_mic_devices(print_out=True):
"""Return a dictionary of available microphone devices.

Args:
print_out: Whether to print the device list to stdout.

Returns:
Mapping of device index to device info.
"""
p = pyaudio.PyAudio()
device_info = {}

Expand Down Expand Up @@ -53,6 +63,7 @@ def available_mic_devices(print_out=True):
return device_info

class Base:
"""Base class for audio sources with subscription support."""
FRAME_SIZE = 160
SAMPLING_RATE = 16000
def __init__(self):
Expand All @@ -61,25 +72,51 @@ def __init__(self):
self._is_thread_started = False

def subscribe(self):
"""Create and register a subscriber queue.

Returns:
Queue receiving audio frames.
"""
q = queue.Queue()
with self._lock:
self._subscriber_queues.append(q)
return q

def _put_to_all_queues(self, data):
"""Fan out audio data to all subscriber queues.

Args:
data: Audio frame data to broadcast.
"""
# Put data into all subscriber queues and the default queue
with self._lock:
for q in self._subscriber_queues:
q.put(data)

def get_audio_data(self, q=None):
"""Get the next audio frame from a subscribed queue.

Args:
q: Subscriber queue to read from.

Returns:
Audio frame data.
"""
return q.get()

def _get_queue_size(self):
"""Return the total queued frame count across subscribers."""
with self._lock:
return sum([len(q.queue) for q in self._subscriber_queues])

class Mic(Base):
"""Microphone-backed audio source.

Args:
audio_gain: Gain multiplier applied to audio frames.
mic_device_index: Optional device index to capture from.
device_name: Optional device name substring to select.
"""

def __init__(self, audio_gain=1.0, mic_device_index=-1, device_name=None):

Expand Down Expand Up @@ -113,6 +150,7 @@ def __init__(self, audio_gain=1.0, mic_device_index=-1, device_name=None):
start=False)

def _read_mic(self):
"""Read frames from the microphone and broadcast to subscribers."""

self.stream.start_stream()

Expand All @@ -123,11 +161,18 @@ def _read_mic(self):
self._put_to_all_queues(d)

def start(self):
"""Start the microphone read thread."""
if not self._is_thread_started:
threading.Thread(target=self._read_mic, daemon=True).start()
self._is_thread_started = True

class Wav(Base):
"""WAV-file-backed audio source.

Args:
wav_file_path: Path to the WAV file.
audio_gain: Gain multiplier applied to audio frames.
"""
def __init__(self, wav_file_path, audio_gain=1.0):
super().__init__()
self.wav_file_path = wav_file_path
Expand All @@ -150,6 +195,7 @@ def __init__(self, wav_file_path, audio_gain=1.0):
self.raw_wav_queue.put(d)

def _read_wav(self):
"""Stream WAV frames in real time and broadcast to subscribers."""
start_time = time.time()
frame_duration = self.FRAME_SIZE / self.SAMPLING_RATE
pygame.mixer.init(frequency=16000, size=-16, channels=1, buffer=512)
Expand All @@ -171,11 +217,21 @@ def _read_wav(self):
self._put_to_all_queues(data)

def start(self):
"""Start the WAV streaming thread."""
if not self._is_thread_started:
threading.Thread(target=self._read_wav, daemon=True).start()
self._is_thread_started = True

class Tcp(Base):
"""TCP audio receiver that provides frames to subscribers.

Args:
ip: IP address to bind/connect.
port: TCP port to bind/connect.
audio_gain: Gain multiplier applied to audio frames.
recv_float32: Whether to read 4-byte float frames.
client_mode: Whether to connect as a client.
"""
def __init__(self, ip='127.0.0.1', port=8501, audio_gain=1.0,recv_float32=False, client_mode=False):
super().__init__()
self.ip = ip
Expand All @@ -189,6 +245,7 @@ def __init__(self, ip='127.0.0.1', port=8501, audio_gain=1.0,recv_float32=False,
self.client_mode = client_mode # クライアントモードオプション

def _server(self):
"""Start the TCP server for incoming audio streams."""
while True:
if self.conn is not None:
time.sleep(0.1)
Expand All @@ -206,6 +263,7 @@ def _server(self):
continue

def _client(self):
"""Start the TCP client for audio streams."""
while True:
if self.conn is not None:
time.sleep(0.1)
Expand All @@ -222,6 +280,7 @@ def _client(self):
continue

def _process(self):
"""Receive audio frames over TCP and broadcast to subscribers."""
import struct
while True:
try:
Expand Down Expand Up @@ -277,11 +336,13 @@ def _process(self):
continue

def start(self):
"""Start the audio processing thread."""
if not self._is_thread_started_process:
threading.Thread(target=self._process, daemon=True).start()
self._is_thread_started_process = True

def start_server(self):
"""Start the TCP server/client thread."""
if not self._is_thread_started_server:
if self.client_mode:
threading.Thread(target=self._client, daemon=True).start()
Expand All @@ -290,15 +351,29 @@ def start_server(self):
self._is_thread_started_server = True

def _send_data_manual(self, data):
"""Send raw audio bytes to the connected peer.

Args:
data: Raw byte payload to send.
"""
if self.conn is None:
raise ConnectionError("No connection established. Call start_server() first.")
self.conn.send(data)

def is_connected(self):
"""Return True when a TCP connection is established."""
return self.conn is not None and self.addr is not None


class TcpMic(Base):
"""Microphone source that sends frames to a TCP server.

Args:
server_ip: Target server IP.
port: Target server port.
audio_gain: Gain multiplier applied to audio frames.
mic_device_index: Device index to capture from.
"""
def __init__(self, server_ip='127.0.0.1', port=8501, audio_gain=1.0, mic_device_index=0):
self.ip = server_ip
self.port = port
Expand All @@ -313,11 +388,13 @@ def __init__(self, server_ip='127.0.0.1', port=8501, audio_gain=1.0, mic_device_
input_device_index=self.mic_device_index)

def connect_server(self):
"""Connect to the TCP server."""
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.ip, self.port))
print('[CLIENT] Connected to the server')

def _start_client(self):
"""Continuously send microphone frames to the server."""
while True:
try:
self.connect_server()
Expand Down Expand Up @@ -349,22 +426,31 @@ def _start_client(self):
time.sleep(0.5)

def start(self):
"""Start the TCP microphone client thread."""
threading.Thread(target=self._start_client, daemon=True).start()


class TcpChunk(Base):
"""TCP receiver that processes raw audio chunks.

Args:
server_ip: Target server IP.
port: Target server port.
"""
def __init__(self, server_ip='127.0.0.1', port=8501):
self.ip = server_ip
self.port = port
self.chunk_size = 1024
self.sock = None

def connect_server(self):
"""Connect to the TCP server."""
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.ip, self.port))
print('[CLIENT] Connected to the server')

def _start_client(self):
"""Receive raw chunks and process them."""
while True:
try:
self.connect_server()
Expand Down Expand Up @@ -392,9 +478,15 @@ def _start_client(self):
time.sleep(0.5)

def start(self):
"""Start the TCP chunk receiver thread."""
threading.Thread(target=self._start_client, daemon=True).start()

def put_chunk(self, chunk_data):
"""Send a chunk of audio data to the connected server.

Args:
chunk_data: Audio frame array to send.
"""
if self.sock is not None:
data_sent = util.conv_floatarray_2_byte(chunk_data)
self.sock.sendall(data_sent)
Expand Down
Loading