Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Our latest internal version of HeartMuLa-7B achieves **comparable performance wi

### ⚙️ Environment Setup

We recommend using `python=3.10` for local deployment.
We recommend using **Python 3.10–3.12** for local deployment. Newer versions such as 3.14 may not yet have prebuilt wheels for key dependencies, in which case `pip` falls back to compiling them from source.

Clone this repo and install locally.

Expand All @@ -72,6 +72,12 @@ cd heartlib
pip install -e .
```

Optional (recommended if you want **reference-audio conditioning** via MuQ-MuLan):

```
pip install -e ".[muq]"
```

Download our pretrained checkpoints from Hugging Face or ModelScope using the following command:

```
Expand Down Expand Up @@ -104,6 +110,12 @@ To generate music, run:
python ./examples/run_music_generation.py --model_path=./ckpt --version="3B"
```

To enable **reference-audio conditioning** (auto-download MuQ-MuLan from Hugging Face):

```
python ./examples/run_music_generation.py --model_path=./ckpt --version="3B" --load_muq_mulan --ref_audio /path/to/ref.wav
```

By default, this command generates a piece of music conditioned on the lyrics and tags provided in the `./assets` folder. The output music is saved to `./assets/output.mp3`.

All parameters:
Expand Down
16 changes: 14 additions & 2 deletions examples/run_lyrics_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()

if torch.backends.mps.is_available():
device = torch.device("mps")
# MPS commonly lacks bf16 support; fp16 is the safest default.
dtype = torch.float16
elif torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
device = torch.device("cpu")
dtype = torch.bfloat16

pipe = HeartTranscriptorPipeline.from_pretrained(
args.model_path,
device=torch.device("cuda"),
dtype=torch.float16,
device=device,
dtype=dtype,
)
with torch.no_grad():
result = pipe(
Expand Down
77 changes: 74 additions & 3 deletions examples/run_music_generation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,64 @@
from heartlib import HeartMuLaGenPipeline
import argparse

import torch

from heartlib import HeartMuLaGenPipeline


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--version", type=str, default="3B")
parser.add_argument("--lyrics", type=str, default="./assets/lyrics.txt")
parser.add_argument("--tags", type=str, default="./assets/tags.txt")
parser.add_argument(
"--ref_audio",
type=str,
default=None,
help="Optional: path to reference audio for MuQ-MuLan conditioning.",
)
parser.add_argument(
"--load_muq_mulan",
action="store_true",
help="Auto-download/load MuQ-MuLan from Hugging Face (requires `pip install muq`).",
)
parser.add_argument(
"--muq_model_id",
type=str,
default="OpenMuQ/MuQ-MuLan-large",
help="Hugging Face model id for MuQ-MuLan.",
)
parser.add_argument(
"--muq_cache_dir",
type=str,
default=None,
help="Optional: Hugging Face cache dir for MuQ-MuLan.",
)
parser.add_argument(
"--muq_revision",
type=str,
default=None,
help="Optional: Hugging Face revision (branch/tag/commit) for MuQ-MuLan.",
)
parser.add_argument(
"--muq_segment_sec",
type=float,
default=10.0,
help="Reference-audio segment length (seconds) fed to MuQ.",
)
parser.add_argument(
"--muq_sample_rate",
type=int,
default=24000,
help="Sample rate expected by MuQ (usually 24 kHz).",
)
parser.add_argument("--save_path", type=str, default="./assets/output.mp3")
parser.add_argument(
"--codes_path",
type=str,
default=None,
help="Optional: save generated audio token frames (torch .pt) for analysis.",
)

parser.add_argument("--max_audio_length_ms", type=int, default=240_000)
parser.add_argument("--topk", type=int, default=50)
Expand All @@ -20,22 +69,44 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()

if torch.backends.mps.is_available():
device = torch.device("mps")
# MPS commonly lacks bf16 support; fp16 is the safest default.
dtype = torch.float16
elif torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
device = torch.device("cpu")
dtype = torch.bfloat16

pipe = HeartMuLaGenPipeline.from_pretrained(
args.model_path,
device=torch.device("cuda"),
dtype=torch.bfloat16,
device=device,
dtype=dtype,
version=args.version,
load_muq_mulan=args.load_muq_mulan,
muq_model_id=args.muq_model_id,
muq_cache_dir=args.muq_cache_dir,
muq_revision=args.muq_revision,
)
with torch.no_grad():
pipe(
{
"lyrics": args.lyrics,
"tags": args.tags,
"ref_audio": args.ref_audio,
"muq_segment_sec": args.muq_segment_sec,
"muq_sample_rate": args.muq_sample_rate,
},
max_audio_length_ms=args.max_audio_length_ms,
save_path=args.save_path,
codes_path=args.codes_path,
topk=args.topk,
temperature=args.temperature,
cfg_scale=args.cfg_scale,
)
print(f"Generated music saved to {args.save_path}")
if args.codes_path:
print(f"Saved audio token frames to {args.codes_path}")
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ name = "heartlib"
version = "0.1.0"
description = "A Python Library."
readme = "README.md"
requires-python = ">=3.9"
# Torch/torchaudio and many scientific wheels may not be available on bleeding-edge
# Python versions (e.g. 3.14), which forces fragile source builds.
requires-python = ">=3.9,<3.13"
license = {text = "CC-BY-NC-4.0"}
authors = [
{name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"}
Expand Down Expand Up @@ -39,6 +41,10 @@ classifiers = [
"Operating System :: OS Independent"
]

[project.optional-dependencies]
# Optional: enables auto-download + inference for MuQ-MuLan reference-audio conditioning.
muq = ["muq"]

[tool.setuptools]
package-dir = {"" = "src"}

Expand Down
24 changes: 24 additions & 0 deletions src/heartlib/accelerators/metal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Optional Metal (MPS) fused kernels for Apple Silicon.

This is intentionally self-contained and opt-in:
- No import-time dependency on Xcode toolchains.
- The extension is built on-demand via `torch.utils.cpp_extension` when enabled.
"""

from __future__ import annotations

from .runtime import metal_supported, metal_build_tools_available
from .jit import load_heartlib_metal_ops
from .rmsnorm import metal_rmsnorm_available, rmsnorm_fp16
from .rope import metal_rope_available, rope_fp16

# Public API of this package: the names re-exported from the `runtime`, `jit`,
# `rmsnorm` and `rope` submodules imported above.
__all__ = [
"metal_supported",
"metal_build_tools_available",
"load_heartlib_metal_ops",
"metal_rmsnorm_available",
"rmsnorm_fp16",
"metal_rope_available",
"rope_fp16",
]

143 changes: 143 additions & 0 deletions src/heartlib/accelerators/metal/jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""JIT build + load the Metal extension.

Built only when explicitly enabled. Requires Xcode command line tools.
"""

from __future__ import annotations

from pathlib import Path
import subprocess
from typing import Any

from .runtime import metal_build_tools_available, metal_supported


def _this_dir() -> Path:
    """Return the absolute, symlink-resolved directory that contains this module."""
    module_file = Path(__file__).resolve()
    return module_file.parent


# Memoized outcome of `load_heartlib_metal_ops`: the loaded extension module on
# success, or the exception from the first failed attempt (re-raised on later
# calls instead of retrying the build).
_CACHED_MOD: Any | None = None
_CACHED_ERR: Exception | None = None


def _xcrun_find(tool: str) -> str:
    """Return the path that `xcrun` reports for `tool` in the macOS SDK.

    Args:
        tool: Name of the toolchain executable to locate (e.g. ``"metal"``).

    Returns:
        The path printed by ``xcrun -sdk macosx --find <tool>``, stripped of
        surrounding whitespace.

    Raises:
        subprocess.CalledProcessError: If `xcrun` exits with a non-zero status.
        RuntimeError: If `xcrun` succeeds but prints an empty path.
    """
    command = ["xcrun", "-sdk", "macosx", "--find", str(tool)]
    # stderr is folded into stdout so any diagnostics travel with the output.
    raw = subprocess.check_output(command, stderr=subprocess.STDOUT)
    located = raw.decode("utf-8", errors="replace").strip()
    if located:
        return located
    raise RuntimeError(f"xcrun returned empty path for tool {tool!r}")


def _run_metal_tool(cmd: list[str], *, failure: str) -> None:
    """Run one toolchain command; raise RuntimeError carrying its output on failure."""
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(
            f"{failure}\n\n"
            f"Command:\n {' '.join(cmd)}\n\n"
            f"stdout:\n{proc.stdout}\n\n"
            f"stderr:\n{proc.stderr}\n"
        )


def _compile_metallib(*, out_dir: Path, verbose: bool) -> Path:
    """Compile the bundled Metal shaders into `heartlib_ops.metallib` in `out_dir`.

    Each `.metal` source next to this module is compiled to an `.air` object
    with `metal`, then all objects are linked with `metallib`. The build is
    skipped when an existing metallib is at least as new as every source.

    Args:
        out_dir: Directory that receives the `.air` objects and the metallib.
            Created (with parents) if it does not exist.
        verbose: When true, print each toolchain command before running it.

    Returns:
        Path to the (possibly pre-existing) `heartlib_ops.metallib`.

    Raises:
        RuntimeError: If a compile or link command exits with a non-zero
            status; the message includes the command and its stdout/stderr.
        Exception: Whatever `_xcrun_find` raises if the tools cannot be located.
    """
    sources = [
        _this_dir() / "rmsnorm.metal",
        _this_dir() / "rope.metal",
    ]
    metallib = out_dir / "heartlib_ops.metallib"

    # Check freshness before touching the toolchain: when nothing needs
    # rebuilding there is no reason to spawn `xcrun` at all.
    if metallib.exists():
        built_at = metallib.stat().st_mtime
        if all(built_at >= src.stat().st_mtime for src in sources):
            return metallib

    metal = _xcrun_find("metal")
    metallib_tool = _xcrun_find("metallib")

    out_dir.mkdir(parents=True, exist_ok=True)

    airs = [out_dir / f"{src.stem}.air" for src in sources]
    # `airs` is derived one-to-one from `sources`, so a plain `zip` is safe.
    # `zip(strict=True)` only exists on Python 3.10+, and the project metadata
    # still allows 3.9.
    for src, air in zip(sources, airs):
        cmd = [metal, "-c", str(src), "-o", str(air)]
        if verbose:
            print("[heartlib] compiling Metal shader:", " ".join(cmd))
        _run_metal_tool(cmd, failure="Failed to compile Metal shaders.")

    cmd2 = [metallib_tool, *[str(air) for air in airs], "-o", str(metallib)]
    if verbose:
        print("[heartlib] linking Metal metallib:", " ".join(cmd2))
    _run_metal_tool(cmd2, failure="Failed to link Metal metallib (`metallib`).")
    return metallib


def load_heartlib_metal_ops(*, verbose: bool = False) -> Any:
    """Build (if needed) and import the `heartlib_metal_ops` extension.

    The outcome is memoized at module level: after a successful load the same
    module object is returned, and after a failure the same exception is
    re-raised without attempting another build.

    Args:
        verbose: Forwarded to the shader compilation step and to
            `torch.utils.cpp_extension` so build commands are printed.

    Returns:
        The loaded extension module.

    Raises:
        RuntimeError: If Metal/MPS is unsupported or the Metal build tools
            are unavailable.
        Exception: Any error raised while compiling the shaders or building
            and loading the extension.
    """
    global _CACHED_MOD, _CACHED_ERR

    # Serve the memoized result of an earlier call, success or failure.
    if _CACHED_MOD is not None:
        return _CACHED_MOD
    if _CACHED_ERR is not None:
        raise _CACHED_ERR

    if not metal_supported():
        _CACHED_ERR = RuntimeError("Metal/MPS is not supported on this runtime")
        raise _CACHED_ERR

    if not metal_build_tools_available():
        _CACHED_ERR = RuntimeError(
            "Metal build tools unavailable.\n\n"
            "heartlib's fused Metal kernels require Xcode's Metal toolchain (`metal`, `metallib`).\n"
            "Install/select it:\n"
            " - `xcode-select --install`\n"
            " - or install Xcode.app then:\n"
            " `sudo xcode-select -s /Applications/Xcode.app/Contents/Developer`\n"
            " `sudo xcodebuild -license accept`\n\n"
            "Verify:\n"
            " `xcrun -sdk macosx --find metal`\n"
            " `xcrun -sdk macosx --find metallib`\n"
        )
        raise _CACHED_ERR

    # Imported lazily, only once the cheap platform checks above have passed.
    import torch.utils.cpp_extension as ce

    try:
        ext_name = "heartlib_metal_ops"
        build_root = Path(ce._get_build_directory(ext_name, verbose=verbose))

        # The metallib is written into the same directory as the extension build.
        _compile_metallib(out_dir=build_root, verbose=verbose)

        objc_source = str(_this_dir() / "ops.mm")
        module = ce.load(
            name=ext_name,
            sources=[objc_source],
            extra_cflags=["-O3", "-std=c++17", "-fobjc-arc"],
            extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
            with_cuda=False,
            is_python_module=True,
            build_directory=str(build_root),
            verbose=verbose,
        )
    except Exception as exc:
        # Remember the failure so later calls fail fast with the same error.
        _CACHED_ERR = exc
        raise

    _CACHED_MOD = module
    return module

Loading