Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion moshi/moshi/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ def load_voice_prompt(self, voice_prompt: str):

def load_voice_prompt_embeddings(self, path: str):
self.voice_prompt = path
state = torch.load(path)
state = torch.load(path, weights_only=True)

self.voice_prompt_audio = None
self.voice_prompt_embeddings = state["embeddings"].to(self.lm_model.device)
Expand Down
6 changes: 3 additions & 3 deletions moshi/moshi/models/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def get_mimi(filename: str | Path,
if _is_safetensors(filename):
load_model(model, filename)
else:
pkg = torch.load(filename, "cpu")
pkg = torch.load(filename, "cpu", weights_only=True)
model.load_state_dict(pkg["model"])
model.set_num_codebooks(8)
return model
Expand Down Expand Up @@ -214,7 +214,7 @@ def get_moshi_lm(
else:
# torch checkpoint
with open(filename, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
state_dict = torch.load(f, map_location="cpu", weights_only=True)
# Patch 1: expand depformer self_attn weights if needed
model_sd = model.state_dict()
for name, tensor in list(state_dict.items()):
Expand Down Expand Up @@ -292,7 +292,7 @@ def _get_moshi_lm_with_offload(
state_dict = load_file(filename, device="cpu")
else:
with open(filename, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
state_dict = torch.load(f, map_location="cpu", weights_only=True)

# Apply weight patches (same as non-offload path)
model_sd = model.state_dict()
Expand Down
10 changes: 9 additions & 1 deletion moshi/moshi/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import argparse
import os
import sys
import tarfile
from pathlib import Path
import json
Expand Down Expand Up @@ -142,7 +143,14 @@ def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Opti
if not voices_dir.exists():
log("info", f"extracting {voices_tgz} to {voices_dir}")
with tarfile.open(voices_tgz, "r:gz") as tar:
tar.extractall(path=voices_tgz.parent)
if sys.version_info >= (3, 12):
tar.extractall(path=voices_tgz.parent, filter='data')
else:
# Safe extraction fallback for Python < 3.12
for member in tar.getmembers():
if member.name.startswith('/') or '..' in member.name:
raise ValueError(f"Unsafe tar member: {member.name}")
tar.extract(member, path=voices_tgz.parent)

if not voices_dir.exists():
raise RuntimeError("voices.tgz did not contain a 'voices/' directory")
Expand Down
18 changes: 16 additions & 2 deletions moshi/moshi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,14 @@ def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Opti
if not voices_dir.exists():
logger.info(f"extracting {voices_tgz} to {voices_dir}")
with tarfile.open(voices_tgz, "r:gz") as tar:
tar.extractall(path=voices_tgz.parent)
if sys.version_info >= (3, 12):
tar.extractall(path=voices_tgz.parent, filter='data')
else:
# Safe extraction fallback for Python < 3.12
for member in tar.getmembers():
if member.name.startswith('/') or '..' in member.name:
raise ValueError(f"Unsafe tar member: {member.name}")
tar.extract(member, path=voices_tgz.parent)

if not voices_dir.exists():
raise RuntimeError("voices.tgz did not contain a 'voices/' directory")
Expand All @@ -346,7 +353,14 @@ def _get_static_path(static: Optional[str]) -> Optional[str]:
dist = dist_tgz.parent / "dist"
if not dist.exists():
with tarfile.open(dist_tgz, "r:gz") as tar:
tar.extractall(path=dist_tgz.parent)
if sys.version_info >= (3, 12):
tar.extractall(path=dist_tgz.parent, filter='data')
else:
# Safe extraction fallback for Python < 3.12
for member in tar.getmembers():
if member.name.startswith('/') or '..' in member.name:
raise ValueError(f"Unsafe tar member: {member.name}")
tar.extract(member, path=dist_tgz.parent)
return str(dist)
elif static != "none":
# When set to the "none" string, we don't serve any static content.
Expand Down