From 465efc7a8dad8846179443ce84e24f2b6319aae3 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Mon, 23 Feb 2026 14:36:30 -0800 Subject: [PATCH 1/2] fix: enable MPS device support for macOS Apple Silicon Uncomment existing MPS backend detection in torch_auto_device() and add "mps" to the DeviceString type. Change --device default from hardcoded "cuda" to None so auto-detection picks the best available backend (CUDA > MPS > CPU). The MPS code path in loaders.py (loading safetensors via CPU then moving to MPS) was already implemented but unreachable due to the server entrypoint having MPS commented out. Tested on M3 Pro (36 GB) with Python 3.12 / PyTorch 2.4.1. Co-Authored-By: Claude Opus 4.6 --- moshi/moshi/server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py index 771f491d..b7e4fdf2 100644 --- a/moshi/moshi/server.py +++ b/moshi/moshi/server.py @@ -52,7 +52,7 @@ logger = setup_logger(__name__) -DeviceString = Literal["cuda"] | Literal["cpu"] #| Literal["mps"] +DeviceString = Literal["cuda"] | Literal["cpu"] | Literal["mps"] def torch_auto_device(requested: Optional[DeviceString] = None) -> torch.device: """Return a torch.device based on the requested string or availability.""" @@ -60,8 +60,8 @@ def torch_auto_device(requested: Optional[DeviceString] = None) -> torch.device: return torch.device(requested) if torch.cuda.is_available(): return torch.device("cuda") - #elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - # return torch.device("mps") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") return torch.device("cpu") @@ -369,7 +369,7 @@ def main(): parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO, help="HF repo to look into, defaults PersonaPlex. " "Use this to select a different pre-trained model.") - parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.") + parser.add_argument("--device", type=str, default=None, help="Device on which to run, auto-detected if not set.") parser.add_argument("--cpu-offload", action="store_true", help="Offload LM model layers to CPU when GPU memory is insufficient. " "Requires 'accelerate' package.") From de665276b94fad20f9f79a34872a543c91194af2 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Mon, 23 Feb 2026 15:52:42 -0800 Subject: [PATCH 2/2] fix: replace index_copy_ with MPS-compatible advanced indexing `aten::index_copy.out` is not implemented for the MPS device in PyTorch 2.4. Replace with equivalent slice assignment in the KV cache which is supported on all backends. Co-Authored-By: Claude Opus 4.6 --- moshi/moshi/modules/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py index 553a2a80..b6a3622a 100644 --- a/moshi/moshi/modules/transformer.py +++ b/moshi/moshi/modules/transformer.py @@ -265,8 +265,8 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult: B, H, T, D = k.shape indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset indexes = indexes % self.capacity - self.cache[0].index_copy_(2, indexes, k) - self.cache[1].index_copy_(2, indexes, v) + self.cache[0][:, :, indexes] = k + self.cache[1][:, :, indexes] = v self.end_offset.add_(T) keys = self.cache[0]