diff --git a/moshi/moshi/models/lm.py b/moshi/moshi/models/lm.py index a64d736b..27668052 100644 --- a/moshi/moshi/models/lm.py +++ b/moshi/moshi/models/lm.py @@ -976,7 +976,7 @@ def load_voice_prompt(self, voice_prompt: str): def load_voice_prompt_embeddings(self, path: str): self.voice_prompt = path - state = torch.load(path) + state = torch.load(path, weights_only=True) self.voice_prompt_audio = None self.voice_prompt_embeddings = state["embeddings"].to(self.lm_model.device) diff --git a/moshi/moshi/models/loaders.py b/moshi/moshi/models/loaders.py index d38dcc3b..3f6c9370 100644 --- a/moshi/moshi/models/loaders.py +++ b/moshi/moshi/models/loaders.py @@ -157,7 +157,7 @@ def get_mimi(filename: str | Path, if _is_safetensors(filename): load_model(model, filename) else: - pkg = torch.load(filename, "cpu") + pkg = torch.load(filename, "cpu", weights_only=True) model.load_state_dict(pkg["model"]) model.set_num_codebooks(8) return model @@ -214,7 +214,7 @@ def get_moshi_lm( else: # torch checkpoint with open(filename, "rb") as f: - state_dict = torch.load(f, map_location="cpu") + state_dict = torch.load(f, map_location="cpu", weights_only=True) # Patch 1: expand depformer self_attn weights if needed model_sd = model.state_dict() for name, tensor in list(state_dict.items()): @@ -292,7 +292,7 @@ def _get_moshi_lm_with_offload( state_dict = load_file(filename, device="cpu") else: with open(filename, "rb") as f: - state_dict = torch.load(f, map_location="cpu") + state_dict = torch.load(f, map_location="cpu", weights_only=True) # Apply weight patches (same as non-offload path) model_sd = model.state_dict() diff --git a/moshi/moshi/offline.py b/moshi/moshi/offline.py index f690620d..29d465b1 100644 --- a/moshi/moshi/offline.py +++ b/moshi/moshi/offline.py @@ -41,6 +41,7 @@ import argparse import os +import sys import tarfile from pathlib import Path import json @@ -142,7 +143,14 @@ def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Opti if not voices_dir.exists(): log("info", f"extracting {voices_tgz} to {voices_dir}") with tarfile.open(voices_tgz, "r:gz") as tar: - tar.extractall(path=voices_tgz.parent) + if sys.version_info >= (3, 12): + tar.extractall(path=voices_tgz.parent, filter='data') + else: + # Safe extraction fallback for Python < 3.12 + for member in tar.getmembers(): + if member.name.startswith('/') or '..' in member.name: + raise ValueError(f"Unsafe tar member: {member.name}") + tar.extract(member, path=voices_tgz.parent) if not voices_dir.exists(): raise RuntimeError("voices.tgz did not contain a 'voices/' directory") diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py index 771f491d..984b7cd7 100644 --- a/moshi/moshi/server.py +++ b/moshi/moshi/server.py @@ -330,7 +330,14 @@ def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Opti if not voices_dir.exists(): logger.info(f"extracting {voices_tgz} to {voices_dir}") with tarfile.open(voices_tgz, "r:gz") as tar: - tar.extractall(path=voices_tgz.parent) + if sys.version_info >= (3, 12): + tar.extractall(path=voices_tgz.parent, filter='data') + else: + # Safe extraction fallback for Python < 3.12 + for member in tar.getmembers(): + if member.name.startswith('/') or '..' in member.name: + raise ValueError(f"Unsafe tar member: {member.name}") + tar.extract(member, path=voices_tgz.parent) if not voices_dir.exists(): raise RuntimeError("voices.tgz did not contain a 'voices/' directory") @@ -346,7 +353,14 @@ def _get_static_path(static: Optional[str]) -> Optional[str]: dist = dist_tgz.parent / "dist" if not dist.exists(): with tarfile.open(dist_tgz, "r:gz") as tar: - tar.extractall(path=dist_tgz.parent) + if sys.version_info >= (3, 12): + tar.extractall(path=dist_tgz.parent, filter='data') + else: + # Safe extraction fallback for Python < 3.12 + for member in tar.getmembers(): + if member.name.startswith('/') or '..' in member.name: + raise ValueError(f"Unsafe tar member: {member.name}") + tar.extract(member, path=dist_tgz.parent) return str(dist) elif static != "none": # When set to the "none" string, we don't serve any static content.