From c2fc45ce307e3f09549615b515fe1e2047e51e52 Mon Sep 17 00:00:00 2001 From: Daniel Owen van Dommelen Date: Mon, 19 Jan 2026 13:04:01 +0100 Subject: [PATCH 1/6] Refactor music generation pipeline to support dynamic device and dtype selection. Update argument handling in `run_music_generation.py` and improve `HeartMuLaGenPipeline` class for better input processing and model execution. --- examples/run_music_generation.py | 20 ++++- src/heartlib/pipelines/music_generation.py | 99 +++++++++++++--------- 2 files changed, 78 insertions(+), 41 deletions(-) diff --git a/examples/run_music_generation.py b/examples/run_music_generation.py index e84e148..710e382 100644 --- a/examples/run_music_generation.py +++ b/examples/run_music_generation.py @@ -1,7 +1,9 @@ -from heartlib import HeartMuLaGenPipeline import argparse + import torch +from heartlib import HeartMuLaGenPipeline + def parse_args(): parser = argparse.ArgumentParser() @@ -20,10 +22,22 @@ def parse_args(): if __name__ == "__main__": args = parse_args() + + if torch.backends.mps.is_available(): + device = torch.device("mps") + # MPS commonly lacks bf16 support; fp16 is the safest default. + dtype = torch.float16 + elif torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 + else: + device = torch.device("cpu") + dtype = torch.bfloat16 + pipe = HeartMuLaGenPipeline.from_pretrained( args.model_path, - device=torch.device("cuda"), - dtype=torch.bfloat16, + device=device, + dtype=dtype, version=args.version, ) with torch.no_grad(): diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index 0e5f971..0fb7145 100644 --- a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -1,15 +1,19 @@ -from transformers.pipelines.base import Pipeline -from tokenizers import Tokenizer -from ..heartmula.modeling_heartmula import HeartMuLa -from ..heartcodec.modeling_heartcodec import HeartCodec -import torch -from typing import Dict, Any, Optional +import json import os from dataclasses import dataclass -from tqdm import tqdm +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch import torchaudio -import json +from tokenizers import Tokenizer +from tqdm import tqdm from transformers import BitsAndBytesConfig +from transformers.pipelines.base import Pipeline +from transformers.utils.generic import ModelOutput + +from ..heartcodec.modeling_heartcodec import HeartCodec +from ..heartmula.modeling_heartmula import HeartMuLa @dataclass @@ -37,12 +41,13 @@ def __init__( device: torch.device, dtype: torch.dtype, ): - super().__init__(model, dtype=dtype) + super().__init__(model, device=device, dtype=dtype) self.model = model self.audio_codec = audio_codec self.muq_mulan = muq_mulan self.text_tokenizer = text_tokenizer self.config = config + self._device = device self._parallel_number = audio_codec.config.num_quantizers + 1 self._muq_dim = model.config.muq_dim @@ -60,17 +65,16 @@ def _sanitize_parameters(self, **kwargs): } return preprocess_kwargs, forward_kwargs, postprocess_kwargs - def preprocess(self, inputs: Dict[str, Any], cfg_scale: float): + def preprocess(self, input_: Dict[str, Any], **preprocess_parameters: Any): + cfg_scale: float = preprocess_parameters.get("cfg_scale", 1.5) - # process tags - tags = inputs["tags"] + tags = input_["tags"] if os.path.isfile(tags): with open(tags, encoding="utf-8") as fp: tags = fp.read() assert isinstance(tags, str), f"tags must be a string, but got {type(tags)}" tags = 
tags.lower() - # encapsulate with special and tokens if not tags.startswith(""): tags = f"{tags}" if not tags.endswith(""): @@ -82,15 +86,13 @@ def preprocess(self, inputs: Dict[str, Any], cfg_scale: float): if tags_ids[-1] != self.config.text_eos_id: tags_ids = tags_ids + [self.config.text_eos_id] - # process reference audio - ref_audio = inputs.get("ref_audio", None) + ref_audio = input_.get("ref_audio", None) if ref_audio is not None: raise NotImplementedError("ref_audio is not supported yet.") muq_embed = torch.zeros([self._muq_dim], dtype=self.dtype) - muq_idx = len(tags_ids) + muq_idx = len(tags) - # process lyrics - lyrics = inputs["lyrics"] + lyrics = input_["lyrics"] if os.path.isfile(lyrics): with open(lyrics, encoding="utf-8") as fp: lyrics = fp.read() @@ -105,7 +107,6 @@ def preprocess(self, inputs: Dict[str, Any], cfg_scale: float): if lyrics_ids[-1] != self.config.text_eos_id: lyrics_ids = lyrics_ids + [self.config.text_eos_id] - # cat them together. tags, ref_audio, lyrics prompt_len = len(tags_ids) + 1 + len(lyrics_ids) tokens = torch.zeros([prompt_len, self._parallel_number], dtype=torch.long) @@ -117,9 +118,9 @@ def preprocess(self, inputs: Dict[str, Any], cfg_scale: float): bs_size = 2 if cfg_scale != 1.0 else 1 - def _cfg_cat(tensor: torch.Tensor, cfg_scale: float): + def _cfg_cat(tensor: torch.Tensor, scale: float) -> torch.Tensor: tensor = tensor.unsqueeze(0) - if cfg_scale != 1.0: + if scale != 1.0: tensor = torch.cat([tensor, tensor], dim=0) return tensor @@ -133,23 +134,38 @@ def _cfg_cat(tensor: torch.Tensor, cfg_scale: float): def _forward( self, - model_inputs: Dict[str, Any], - max_audio_length_ms: int, - temperature: float, - topk: int, - cfg_scale: float, - ): - prompt_tokens = model_inputs["tokens"] - prompt_tokens_mask = model_inputs["tokens_mask"] - continuous_segment = model_inputs["muq_embed"] - starts = model_inputs["muq_idx"] - prompt_pos = model_inputs["pos"] + input_tensors: Dict[str, Any], + **forward_parameters: Any, + ) -> ModelOutput: + max_audio_length_ms: int = forward_parameters.get( + "max_audio_length_ms", 120_000 + ) + temperature: float = forward_parameters.get("temperature", 1.0) + topk: int = forward_parameters.get("topk", 50) + cfg_scale: float = forward_parameters.get("cfg_scale", 1.5) + + prompt_tokens = input_tensors["tokens"] + prompt_tokens_mask = input_tensors["tokens_mask"] + continuous_segment = input_tensors["muq_embed"] + starts = input_tensors["muq_idx"] + prompt_pos = input_tensors["pos"] frames = [] bs_size = 2 if cfg_scale != 1.0 else 1 self.model.setup_caches(bs_size) - with torch.autocast(device_type=self.device.type, dtype=self.dtype): + + device_type = ( + self._device.type if isinstance(self._device, torch.device) else "cpu" + ) + use_autocast = device_type in ("cuda", "cpu") + + autocast_ctx = ( + torch.autocast(device_type=device_type, dtype=self.dtype) + if use_autocast + else nullcontext() + ) + with autocast_ctx: curr_token = self.model.generate_frame( tokens=prompt_tokens, tokens_mask=prompt_tokens_mask, @@ -162,7 +178,7 @@ def _forward( ) frames.append(curr_token[0:1,]) - def _pad_audio_token(token: torch.Tensor): + def _pad_audio_token(token: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: padded_token = ( torch.ones( (token.shape[0], self._parallel_number), @@ -183,7 +199,12 @@ def _pad_audio_token(token: torch.Tensor): for i in tqdm(range(max_audio_frames)): curr_token, curr_token_mask = _pad_audio_token(curr_token) - with torch.autocast(device_type=self.device.type, dtype=self.dtype): + autocast_ctx = ( 
+ torch.autocast(device_type=device_type, dtype=self.dtype) + if use_autocast + else nullcontext() + ) + with autocast_ctx: curr_token = self.model.generate_frame( tokens=curr_token, tokens_mask=curr_token_mask, @@ -199,9 +220,12 @@ def _pad_audio_token(token: torch.Tensor): frames.append(curr_token[0:1,]) frames = torch.stack(frames).permute(1, 2, 0).squeeze(0) wav = self.audio_codec.detokenize(frames) - return {"wav": wav} + return ModelOutput(wav=wav) - def postprocess(self, model_outputs: Dict[str, Any], save_path: str): + def postprocess( + self, model_outputs: ModelOutput, **postprocess_parameters: Any + ) -> None: + save_path: str = postprocess_parameters.get("save_path", "output.mp3") wav = model_outputs["wav"] torchaudio.save(save_path, wav, 48000) @@ -214,7 +238,6 @@ def from_pretrained( version: str, bnb_config: Optional[BitsAndBytesConfig] = None, ): - if os.path.exists( heartcodec_path := os.path.join(pretrained_path, "HeartCodec-oss") ): From 55a74954729c40555d189feafa81ebad1d65a218 Mon Sep 17 00:00:00 2001 From: Daniel Owen van Dommelen Date: Mon, 19 Jan 2026 13:12:37 +0100 Subject: [PATCH 2/6] Enhance device and dtype handling in lyrics transcription and heart codec model. Update `run_lyrics_transcription.py` to dynamically select device based on availability, and modify `HeartCodec` to determine device from input tensor or model parameters. Improve `HeartMuLaGenPipeline` to support autocast on MPS for better performance. --- examples/run_lyrics_transcription.py | 16 ++++++++++++++-- src/heartlib/heartcodec/modeling_heartcodec.py | 13 ++++++++++++- src/heartlib/pipelines/music_generation.py | 3 ++- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/examples/run_lyrics_transcription.py b/examples/run_lyrics_transcription.py index 4f2c2f8..e3133b1 100644 --- a/examples/run_lyrics_transcription.py +++ b/examples/run_lyrics_transcription.py @@ -13,10 +13,22 @@ def parse_args(): if __name__ == "__main__": args = parse_args() + + if torch.backends.mps.is_available(): + device = torch.device("mps") + # MPS commonly lacks bf16 support; fp16 is the safest default. + dtype = torch.float16 + elif torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 + else: + device = torch.device("cpu") + dtype = torch.bfloat16 + pipe = HeartTranscriptorPipeline.from_pretrained( args.model_path, - device=torch.device("cuda"), - dtype=torch.float16, + device=device, + dtype=dtype, ) with torch.no_grad(): result = pipe( diff --git a/src/heartlib/heartcodec/modeling_heartcodec.py b/src/heartlib/heartcodec/modeling_heartcodec.py index e1bc613..8510192 100644 --- a/src/heartlib/heartcodec/modeling_heartcodec.py +++ b/src/heartlib/heartcodec/modeling_heartcodec.py @@ -5,6 +5,7 @@ from transformers.modeling_utils import PreTrainedModel import math import numpy as np +from typing import Optional, Union class HeartCodec(PreTrainedModel): @@ -62,8 +63,18 @@ def detokenize( num_steps=10, disable_progress=False, guidance_scale=1.25, - device="cuda", + device: Optional[Union[str, torch.device]] = None, ): + if device is None: + # Prefer the input tensor device; fall back to the model's parameters. 
+ if isinstance(codes, torch.Tensor): + device = codes.device + else: + try: + device = next(self.parameters()).device + except StopIteration: + device = torch.device("cpu") + codes = codes.unsqueeze(0).to(device) first_latent = torch.randn(codes.shape[0], int(duration * 25), 256).to( device diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index 0fb7145..993c42b 100644 --- a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -158,7 +158,8 @@ def _forward( device_type = ( self._device.type if isinstance(self._device, torch.device) else "cpu" ) - use_autocast = device_type in ("cuda", "cpu") + # Autocast is supported on MPS as well (and is important for perf/memory). + use_autocast = device_type in ("cuda", "cpu", "mps") autocast_ctx = ( torch.autocast(device_type=device_type, dtype=self.dtype) From 082f715405d379ae713bd9d3976741fd65717ec8 Mon Sep 17 00:00:00 2001 From: Daniel Owen van Dommelen Date: Mon, 19 Jan 2026 13:30:44 +0100 Subject: [PATCH 3/6] Refactor `HeartMuLaGenPipeline` to improve autocast handling and optimize audio token padding. Introduce a context manager for autocast that gracefully handles unsupported cases, and preallocate buffers for audio tokens to enhance performance during generation. --- src/heartlib/pipelines/music_generation.py | 83 ++++++++++++---------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index 993c42b..78e53ff 100644 --- a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -158,15 +158,20 @@ def _forward( device_type = ( self._device.type if isinstance(self._device, torch.device) else "cpu" ) - # Autocast is supported on MPS as well (and is important for perf/memory). - use_autocast = device_type in ("cuda", "cpu", "mps") + # Autocast support varies by PyTorch build/version (not all support "mps"). + # Prefer autocast when available, but never fail if unsupported. + def _autocast_ctx(): + try: + return torch.autocast(device_type=device_type, dtype=self.dtype) + except (RuntimeError, TypeError, ValueError): + return nullcontext() - autocast_ctx = ( - torch.autocast(device_type=device_type, dtype=self.dtype) - if use_autocast - else nullcontext() - ) - with autocast_ctx: + autocast_ctx = _autocast_ctx() + + # Keep a stable view of the base position tensor to avoid re-slicing every step. + base_pos = prompt_pos[..., -1:] + + with torch.inference_mode(), autocast_ctx: curr_token = self.model.generate_frame( tokens=prompt_tokens, tokens_mask=prompt_tokens_mask, @@ -177,49 +182,49 @@ def _forward( continuous_segments=continuous_segment, starts=starts, ) - frames.append(curr_token[0:1,]) - - def _pad_audio_token(token: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - padded_token = ( - torch.ones( - (token.shape[0], self._parallel_number), - device=token.device, - dtype=torch.long, - ) - * self.config.empty_id + + # Preallocate the padded audio token + mask and reuse them every step. 
+ padded_token = torch.full( + (curr_token.shape[0], 1, self._parallel_number), + fill_value=self.config.empty_id, + device=curr_token.device, + dtype=torch.long, ) - padded_token[:, :-1] = token - padded_token = padded_token.unsqueeze(1) - padded_token_mask = torch.ones_like( - padded_token, device=token.device, dtype=torch.bool + padded_token_mask = torch.ones( + (curr_token.shape[0], 1, self._parallel_number), + device=curr_token.device, + dtype=torch.bool, ) padded_token_mask[..., -1] = False - return padded_token, padded_token_mask - - max_audio_frames = max_audio_length_ms // 80 - for i in tqdm(range(max_audio_frames)): - curr_token, curr_token_mask = _pad_audio_token(curr_token) - autocast_ctx = ( - torch.autocast(device_type=device_type, dtype=self.dtype) - if use_autocast - else nullcontext() + max_audio_frames = max_audio_length_ms // 80 + # Preallocate a frame buffer for the *un-padded* audio tokens (first sample only). + frame_buf = torch.empty( + (max_audio_frames + 1, curr_token.shape[1]), + device=curr_token.device, + dtype=curr_token.dtype, ) - with autocast_ctx: + frame_buf[0] = curr_token[0] + frame_len = 1 + + for i in tqdm(range(max_audio_frames)): + padded_token[:, 0, :-1] = curr_token curr_token = self.model.generate_frame( - tokens=curr_token, - tokens_mask=curr_token_mask, - input_pos=prompt_pos[..., -1:] + i + 1, + tokens=padded_token, + tokens_mask=padded_token_mask, + input_pos=base_pos + i + 1, temperature=temperature, topk=topk, cfg_scale=cfg_scale, continuous_segments=None, starts=None, ) - if torch.any(curr_token[0:1, :] >= self.config.audio_eos_id): - break - frames.append(curr_token[0:1,]) - frames = torch.stack(frames).permute(1, 2, 0).squeeze(0) + if torch.any(curr_token[0:1, :] >= self.config.audio_eos_id): + break + frame_buf[frame_len] = curr_token[0] + frame_len += 1 + + frames = frame_buf[:frame_len].transpose(0, 1).contiguous() wav = self.audio_codec.detokenize(frames) return ModelOutput(wav=wav) From a5f4aca17b4fe7da153b51bf0fa09dfce4706e6c Mon Sep 17 00:00:00 2001 From: Daniel Owen van Dommelen Date: Mon, 19 Jan 2026 13:55:43 +0100 Subject: [PATCH 4/6] Implement Metal support in `torchtune_metal.py` for optimized inference on MPS. Update `pyproject.toml` to include the optimizer package directory. Enhance `HeartMuLaGenPipeline` to optionally enable Metal optimizations during model execution, improving performance for Llama blocks. 
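
For context, the fast path introduced here is opt-in and leaves default behaviour unchanged: it is gated on the HEARTLIB_ENABLE_MPS_METAL environment variable that this patch adds below. A minimal usage sketch, assuming an MPS-capable machine; the checkpoint path and version string are placeholders, not values taken from this series:

    import os

    import torch

    from heartlib import HeartMuLaGenPipeline

    # Opt in to the custom Metal kernels before the pipeline is constructed;
    # HEARTLIB_MPS_METAL_VERBOSE=1 additionally prints a short patch report.
    os.environ["HEARTLIB_ENABLE_MPS_METAL"] = "1"
    os.environ["HEARTLIB_MPS_METAL_VERBOSE"] = "1"

    pipe = HeartMuLaGenPipeline.from_pretrained(
        "/path/to/HeartMuLa",       # placeholder checkpoint directory
        device=torch.device("mps"),
        dtype=torch.float16,
        version="oss",              # placeholder version tag
    )

If the optional kernels are missing or the toolchain is unavailable, construction still succeeds and the pipeline falls back to the stock ops.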
--- pyproject.toml | 2 +- src/heartlib/accelerators/torchtune_metal.py | 226 ++++++++++++++++++ src/heartlib/heartcodec/models/transformer.py | 11 +- src/heartlib/pipelines/music_generation.py | 13 + 4 files changed, 249 insertions(+), 3 deletions(-) create mode 100644 src/heartlib/accelerators/torchtune_metal.py diff --git a/pyproject.toml b/pyproject.toml index 140fe04..dd152f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ classifiers = [ ] [tool.setuptools] -package-dir = {"" = "src"} +package-dir = {"" = "src", "optimizer" = "optimizer"} [tool.setuptools.packages.find] where = ["src"] diff --git a/src/heartlib/accelerators/torchtune_metal.py b/src/heartlib/accelerators/torchtune_metal.py new file mode 100644 index 0000000..bf26b72 --- /dev/null +++ b/src/heartlib/accelerators/torchtune_metal.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + + +@dataclass(frozen=True) +class MetalPatchReport: + rmsnorm_replaced: int = 0 + rope_wrapped: int = 0 + enabled: bool = False + reason: str = "" + + +class _MetalRMSNorm(nn.Module): + """Drop-in replacement for torchtune.modules.rms_norm.RMSNorm (inference-safe).""" + + def __init__(self, *, scale: nn.Parameter, eps: float): + super().__init__() + self.eps = float(eps) + # Keep parameter name stable for state_dict compatibility. + self.scale = scale + + # Lazy import: optimizer is an optional local package. + self._metal_impl = None + try: + from optimizer.metal.rmsnorm import rmsnorm_fp16 as _metal_rmsnorm + + self._metal_impl = _metal_rmsnorm + except Exception: + self._metal_impl = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if ( + self._metal_impl is not None + and x.device.type == "mps" + and x.dtype in (torch.float16, torch.float32) + ): + return self._metal_impl(x=x, weight=self.scale, eps=self.eps) + + # Fallback matches torchtune RMSNorm implementation (compute in fp32). + x_fp32 = x.float() + x_normed = ( + x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + ).type_as(x) + return x_normed * self.scale + + +class _MetalLlama3ScaledRoPE(nn.Module): + """Wrapper for torchtune.models.llama3_1._position_embeddings.Llama3ScaledRoPE. + + Uses Metal RoPE when the input positions are compatible with the kernel: + - Training/unpacked: input_pos is None (positions 0..S-1) + - Prompt paths that pass input_pos but are still 0..S-1 (same for all batches) + - Decode step: input_pos is a scalar position (same for all batches) with seq_len==1 + Otherwise, falls back to the original implementation. + """ + + def __init__(self, inner: nn.Module): + super().__init__() + # Avoid accidental wrapper-of-wrapper chains. + while isinstance(inner, _MetalLlama3ScaledRoPE): + inner = inner.inner + self.inner = inner + + self._metal_impl = None + try: + from optimizer.metal.rope import rope_fp16 as _metal_rope + + self._metal_impl = _metal_rope + except Exception: + self._metal_impl = None + + def _maybe_expand_cache(self, need_len: int) -> None: + cache = getattr(self.inner, "cache", None) + if cache is None: + return + if int(cache.shape[0]) >= int(need_len): + return + # Rebuild cache to the required length. + build = getattr(self.inner, "build_rope_cache", None) + if callable(build): + build(int(need_len)) + + def _inner_rope(self) -> nn.Module: + # Unwrap nested wrappers (defensive) and guard against cycles. 
+ inner: nn.Module = self.inner + seen: set[int] = set() + while isinstance(inner, _MetalLlama3ScaledRoPE): + if id(inner) in seen: + break + seen.add(id(inner)) + inner = inner.inner + return inner + + def forward(self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + inner = self._inner_rope() + if ( + self._metal_impl is None + or x.device.type != "mps" + or x.dtype not in (torch.float16, torch.float32) + ): + return inner(x, input_pos=input_pos) + + # torchtune shape: [b, s, n_h, h_d] + if x.ndim != 4: + return inner(x, input_pos=input_pos) + + b, s, nh, hd = (int(x.shape[0]), int(x.shape[1]), int(x.shape[2]), int(x.shape[3])) + rot_dim = int(getattr(inner, "dim", hd)) + if rot_dim <= 0 or (rot_dim % 2) != 0 or rot_dim > hd: + return inner(x, input_pos=input_pos) + + cache = getattr(inner, "cache", None) + if cache is None: + return inner(x, input_pos=input_pos) + + # Fast-path selection for cos/sin. + cos: Optional[torch.Tensor] = None + sin: Optional[torch.Tensor] = None + + if input_pos is None: + self._maybe_expand_cache(s) + cos = cache[:s, :, 0] + sin = cache[:s, :, 1] + else: + ip = input_pos + # Common prompt path: ip is [B,S] and equals arange(S) (same for all batches). + if ip.ndim == 2 and int(ip.shape[1]) == s: + # Only accept when identical across batch and sequential. + ip0 = ip[0] + ar = torch.arange(s, device=ip.device, dtype=ip.dtype) + if torch.equal(ip0, ar) and torch.all(ip == ip0): + self._maybe_expand_cache(s) + cos = cache[:s, :, 0] + sin = cache[:s, :, 1] + # Decode path: ip is [B,1] and identical across batch; use a single-row cos/sin. + if cos is None and ip.numel() == b and ip.ndim == 2 and int(ip.shape[1]) == 1 and s == 1: + v0 = ip.view(-1)[0] + if torch.all(ip == v0): + pos = int(v0.item()) + self._maybe_expand_cache(pos + 1) + cos = cache[pos : pos + 1, :, 0] + sin = cache[pos : pos + 1, :, 1] + + if cos is None or sin is None: + return inner(x, input_pos=input_pos) + + # Convert torchtune layout [B,S,H,D] -> kernel layout [B,H,S,D] + x2 = x.permute(0, 2, 1, 3).contiguous() + y2 = self._metal_impl(x=x2, cos=cos, sin=sin, rot_dim=rot_dim) + return y2.permute(0, 2, 1, 3).contiguous() + + +def try_enable_torchtune_metal( + model: nn.Module, + *, + enabled: Optional[bool] = None, + verbose: bool = False, +) -> MetalPatchReport: + """Best-effort: patch torchtune Llama3.* modules to use Metal RMSNorm/RoPE on MPS. + + This is intentionally opt-in and safe: + - If optimizer/metal is missing, does nothing + - If torchtune internals differ, does nothing + - Falls back to original ops when the kernel can't represent input_pos layouts + """ + if enabled is None: + enabled = os.getenv("HEARTLIB_ENABLE_MPS_METAL", "0") == "1" + if not enabled: + return MetalPatchReport(enabled=False, reason="disabled") + + try: + tt_rms_mod = __import__("torchtune.modules.rms_norm", fromlist=["RMSNorm"]) + TT_RMSNorm = getattr(tt_rms_mod, "RMSNorm") + tt_rope_mod = __import__( + "torchtune.models.llama3_1._position_embeddings", + fromlist=["Llama3ScaledRoPE"], + ) + TT_RoPE = getattr(tt_rope_mod, "Llama3ScaledRoPE") + except Exception as e: + return MetalPatchReport(enabled=False, reason=f"torchtune import failed: {e}") + + rms_count = 0 + rope_count = 0 + rope_wrappers: Dict[int, nn.Module] = {} + + for parent in model.modules(): + # Never patch inside our own wrappers; that can create wrapper chains/cycles. + if isinstance(parent, _MetalLlama3ScaledRoPE): + continue + for name, child in list(parent.named_children()): + # Replace RMSNorm. 
+ if isinstance(child, TT_RMSNorm) and not isinstance(child, _MetalRMSNorm): + scale = getattr(child, "scale", None) + eps = float(getattr(child, "eps", 1e-6)) + if isinstance(scale, nn.Parameter): + setattr(parent, name, _MetalRMSNorm(scale=scale, eps=eps)) + rms_count += 1 + continue + + # Wrap RoPE. + if isinstance(child, TT_RoPE) and not isinstance(child, _MetalLlama3ScaledRoPE): + key = id(child) + wrapped = rope_wrappers.get(key) + if wrapped is None: + wrapped = _MetalLlama3ScaledRoPE(child) + rope_wrappers[key] = wrapped + setattr(parent, name, wrapped) + rope_count += 1 + + if verbose: + print( + f"[heartlib] torchtune metal patch: rmsnorm_replaced={rms_count}, rope_wrapped={rope_count}" + ) + return MetalPatchReport( + rmsnorm_replaced=rms_count, + rope_wrapped=rope_count, + enabled=True, + reason="ok", + ) + diff --git a/src/heartlib/heartcodec/models/transformer.py b/src/heartlib/heartcodec/models/transformer.py index fe9918c..6390a1c 100644 --- a/src/heartlib/heartcodec/models/transformer.py +++ b/src/heartlib/heartcodec/models/transformer.py @@ -419,11 +419,18 @@ def __init__(self, embedding_dim: int, size_emb_dim: int): def timestep_embedding(self, timesteps, max_period=10000, scale=1000): half = self.flow_t_size // 2 + # NOTE: `.type(timesteps.type())` breaks on MPS (e.g. "torch.mps.FloatTensor"). + # Use an explicit dtype conversion instead. freqs = torch.exp( -math.log(max_period) - * torch.arange(start=0, end=half, device=timesteps.device) + * torch.arange( + start=0, + end=half, + device=timesteps.device, + dtype=torch.float32, + ) / half - ).type(timesteps.type()) + ).to(dtype=timesteps.dtype) args = timesteps[:, None] * freqs[None] * scale embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if self.flow_t_size % 2: diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index 78e53ff..3f9d3b3 100644 --- a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -14,6 +14,7 @@ from ..heartcodec.modeling_heartcodec import HeartCodec from ..heartmula.modeling_heartmula import HeartMuLa +from ..accelerators.torchtune_metal import try_enable_torchtune_metal @dataclass @@ -49,6 +50,18 @@ def __init__( self.config = config self._device = device + # Optional, opt-in MPS fast path (custom Metal kernels) for torchtune Llama blocks. + # Enable with: HEARTLIB_ENABLE_MPS_METAL=1 + try: + try_enable_torchtune_metal( + self.model, + enabled=(os.getenv("HEARTLIB_ENABLE_MPS_METAL", "0") == "1"), + verbose=(os.getenv("HEARTLIB_MPS_METAL_VERBOSE", "0") == "1"), + ) + except Exception: + # Never fail inference if optional kernels are unavailable. + pass + self._parallel_number = audio_codec.config.num_quantizers + 1 self._muq_dim = model.config.muq_dim From b56ec87a9ade14d22014955731f670e334778e51 Mon Sep 17 00:00:00 2001 From: Daniel Owen van Dommelen Date: Mon, 19 Jan 2026 14:13:45 +0100 Subject: [PATCH 5/6] Implement Metal support for RMSNorm and RoPE operations, including new Metal kernels and Python wrappers. Update `pyproject.toml` to remove the optimizer package directory. Enhance runtime detection for Metal support and build tools availability. 
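
As a correctness anchor for the kernels added below, a plain PyTorch sketch of the semantics implemented by rmsnorm.metal and rope.metal: RMSNorm accumulates in fp32 and then rescales, and RoPE pairs element i with element i + rot_dim/2, passing dimensions past rot_dim through unchanged. Shapes follow the [B, H, T, D] kernel layout described in ops.mm; this is a reference sketch, not part of the patched code paths:

    import torch

    def rmsnorm_reference(x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Normalize in fp32 for stability, cast back, then apply the learned scale.
        x_fp32 = x.float()
        x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
        return x_normed.type_as(x) * scale

    def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rot_dim: int) -> torch.Tensor:
        # x: [B, H, T, D]; cos/sin: [T, rot_dim // 2]. Elements i and i + rot_dim/2
        # form a rotation pair; anything beyond rot_dim is copied through.
        half = rot_dim // 2
        x1, x2, rest = x[..., :half], x[..., half:rot_dim], x[..., rot_dim:]
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos
        return torch.cat([y1, y2, rest], dim=-1)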
--- pyproject.toml | 2 +- src/heartlib/accelerators/metal/__init__.py | 24 + src/heartlib/accelerators/metal/jit.py | 143 ++++++ src/heartlib/accelerators/metal/ops.mm | 478 ++++++++++++++++++ src/heartlib/accelerators/metal/rmsnorm.metal | 179 +++++++ src/heartlib/accelerators/metal/rmsnorm.py | 110 ++++ src/heartlib/accelerators/metal/rope.metal | 127 +++++ src/heartlib/accelerators/metal/rope.py | 102 ++++ src/heartlib/accelerators/metal/runtime.py | 51 ++ src/heartlib/accelerators/torchtune_metal.py | 4 +- 10 files changed, 1217 insertions(+), 3 deletions(-) create mode 100644 src/heartlib/accelerators/metal/__init__.py create mode 100644 src/heartlib/accelerators/metal/jit.py create mode 100644 src/heartlib/accelerators/metal/ops.mm create mode 100644 src/heartlib/accelerators/metal/rmsnorm.metal create mode 100644 src/heartlib/accelerators/metal/rmsnorm.py create mode 100644 src/heartlib/accelerators/metal/rope.metal create mode 100644 src/heartlib/accelerators/metal/rope.py create mode 100644 src/heartlib/accelerators/metal/runtime.py diff --git a/pyproject.toml b/pyproject.toml index dd152f6..140fe04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ classifiers = [ ] [tool.setuptools] -package-dir = {"" = "src", "optimizer" = "optimizer"} +package-dir = {"" = "src"} [tool.setuptools.packages.find] where = ["src"] diff --git a/src/heartlib/accelerators/metal/__init__.py b/src/heartlib/accelerators/metal/__init__.py new file mode 100644 index 0000000..00d9a5a --- /dev/null +++ b/src/heartlib/accelerators/metal/__init__.py @@ -0,0 +1,24 @@ +"""Optional Metal (MPS) fused kernels for Apple Silicon. + +This is intentionally self-contained and opt-in: +- No import-time dependency on Xcode toolchains. +- The extension is built on-demand via `torch.utils.cpp_extension` when enabled. +""" + +from __future__ import annotations + +from .runtime import metal_supported, metal_build_tools_available +from .jit import load_heartlib_metal_ops +from .rmsnorm import metal_rmsnorm_available, rmsnorm_fp16 +from .rope import metal_rope_available, rope_fp16 + +__all__ = [ + "metal_supported", + "metal_build_tools_available", + "load_heartlib_metal_ops", + "metal_rmsnorm_available", + "rmsnorm_fp16", + "metal_rope_available", + "rope_fp16", +] + diff --git a/src/heartlib/accelerators/metal/jit.py b/src/heartlib/accelerators/metal/jit.py new file mode 100644 index 0000000..688234d --- /dev/null +++ b/src/heartlib/accelerators/metal/jit.py @@ -0,0 +1,143 @@ +"""JIT build + load the Metal extension. + +Built only when explicitly enabled. Requires Xcode command line tools. 
+""" + +from __future__ import annotations + +from pathlib import Path +import subprocess +from typing import Any + +from .runtime import metal_build_tools_available, metal_supported + + +def _this_dir() -> Path: + return Path(__file__).resolve().parent + + +_CACHED_MOD: Any | None = None +_CACHED_ERR: Exception | None = None + + +def _xcrun_find(tool: str) -> str: + out = subprocess.check_output( + ["xcrun", "-sdk", "macosx", "--find", str(tool)], stderr=subprocess.STDOUT + ) + p = out.decode("utf-8", errors="replace").strip() + if not p: + raise RuntimeError(f"xcrun returned empty path for tool {tool!r}") + return p + + +def _compile_metallib(*, out_dir: Path, verbose: bool) -> Path: + """Compile minimal Metal shaders -> `heartlib_ops.metallib` in `out_dir`.""" + sources = [ + _this_dir() / "rmsnorm.metal", + _this_dir() / "rope.metal", + ] + airs = [out_dir / f"{src.stem}.air" for src in sources] + metallib = out_dir / "heartlib_ops.metallib" + + metal = _xcrun_find("metal") + metallib_tool = _xcrun_find("metallib") + + if metallib.exists(): + mt = metallib.stat().st_mtime + if all(mt >= src.stat().st_mtime for src in sources): + return metallib + + out_dir.mkdir(parents=True, exist_ok=True) + + for src, air in zip(sources, airs, strict=True): + cmd = [metal, "-c", str(src), "-o", str(air)] + if verbose: + print("[heartlib] compiling Metal shader:", " ".join(cmd)) + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + raise RuntimeError( + "Failed to compile Metal shaders.\n\n" + f"Command:\n {' '.join(cmd)}\n\n" + f"stdout:\n{proc.stdout}\n\n" + f"stderr:\n{proc.stderr}\n" + ) + + cmd2 = [metallib_tool, *[str(air) for air in airs], "-o", str(metallib)] + if verbose: + print("[heartlib] linking Metal metallib:", " ".join(cmd2)) + proc2 = subprocess.run(cmd2, capture_output=True, text=True) + if proc2.returncode != 0: + raise RuntimeError( + "Failed to link Metal metallib (`metallib`).\n\n" + f"Command:\n {' '.join(cmd2)}\n\n" + f"stdout:\n{proc2.stdout}\n\n" + f"stderr:\n{proc2.stderr}\n" + ) + return metallib + + +def load_heartlib_metal_ops(*, verbose: bool = False) -> Any: + """Build (if needed) and import the `heartlib_metal_ops` extension.""" + global _CACHED_MOD, _CACHED_ERR + if _CACHED_MOD is not None: + return _CACHED_MOD + if _CACHED_ERR is not None: + raise _CACHED_ERR + + if not metal_supported(): + err = RuntimeError("Metal/MPS is not supported on this runtime") + _CACHED_ERR = err + raise err + if not metal_build_tools_available(): + err = RuntimeError( + "Metal build tools unavailable.\n\n" + "heartlib's fused Metal kernels require Xcode's Metal toolchain (`metal`, `metallib`).\n" + "Install/select it:\n" + " - `xcode-select --install`\n" + " - or install Xcode.app then:\n" + " `sudo xcode-select -s /Applications/Xcode.app/Contents/Developer`\n" + " `sudo xcodebuild -license accept`\n\n" + "Verify:\n" + " `xcrun -sdk macosx --find metal`\n" + " `xcrun -sdk macosx --find metallib`\n" + ) + _CACHED_ERR = err + raise err + + import torch.utils.cpp_extension as ce + + try: + name = "heartlib_metal_ops" + build_dir = Path(ce._get_build_directory(name, verbose=verbose)) + + _compile_metallib(out_dir=build_dir, verbose=verbose) + + src_ops = str(_this_dir() / "ops.mm") + extra_cflags = [ + "-O3", + "-std=c++17", + "-fobjc-arc", + ] + extra_ldflags = [ + "-framework", + "Metal", + "-framework", + "Foundation", + ] + mod = ce.load( + name=name, + sources=[src_ops], + extra_cflags=extra_cflags, + extra_ldflags=extra_ldflags, + with_cuda=False, + 
is_python_module=True, + build_directory=str(build_dir), + verbose=verbose, + ) + except Exception as e: + _CACHED_ERR = e + raise + + _CACHED_MOD = mod + return mod + diff --git a/src/heartlib/accelerators/metal/ops.mm b/src/heartlib/accelerators/metal/ops.mm new file mode 100644 index 0000000..d2178bb --- /dev/null +++ b/src/heartlib/accelerators/metal/ops.mm @@ -0,0 +1,478 @@ +#include + +#include +#include + +#include +#include +#include +#include + +#import +#import + +namespace fs = std::filesystem; + +namespace { + +// Must match `RMSNormParams` in `rmsnorm.metal`. +struct RMSNormParams { + uint32_t d_model; + float eps; + uint32_t stride_row; +}; + +// Must match `RoPEParams` in `rope.metal`. +struct RoPEParams { + uint32_t d_model; + uint32_t rot_dim; + uint32_t half_rot; + uint32_t seq_len; +}; + +constexpr NSUInteger kThreadsPerThreadgroup = 256; + +static id g_lib = nil; +static id g_pipeline_rmsnorm = nil; +static id g_pipeline_rmsnorm_fp32 = nil; +static id g_pipeline_rmsnorm_noweight = nil; +static id g_pipeline_rmsnorm_noweight_fp32 = nil; +static id g_pipeline_rmsnorm_fwd_inv = nil; +static id g_pipeline_rmsnorm_fwd_inv_fp32 = nil; +static id g_pipeline_rmsnorm_noweight_fwd_inv = nil; +static id g_pipeline_rmsnorm_noweight_fwd_inv_fp32 = nil; +static id g_pipeline_rope = nil; +static id g_pipeline_rope_fp32 = nil; +static id g_pipeline_rope_bwd = nil; +static id g_pipeline_rope_bwd_fp32 = nil; +static std::mutex g_pipeline_mutex; + +static std::string metallib_path_for_this_module() { + Dl_info info; + if (dladdr((void*)&metallib_path_for_this_module, &info) == 0 || info.dli_fname == nullptr) { + return std::string(); + } + fs::path so_path(info.dli_fname); + fs::path lib_path = so_path.parent_path() / "heartlib_ops.metallib"; + return lib_path.string(); +} + +static void ensure_library_locked(id device) { + if (g_lib != nil) { + return; + } + + const std::string lib_path = metallib_path_for_this_module(); + TORCH_CHECK(!lib_path.empty(), "heartlib_metal_ops: failed to locate extension path via dladdr()"); + + NSString* ns_path = [NSString stringWithUTF8String:lib_path.c_str()]; + NSURL* url = [NSURL fileURLWithPath:ns_path]; + NSError* err = nil; + g_lib = [device newLibraryWithURL:url error:&err]; + if (g_lib == nil) { + const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error"; + TORCH_CHECK(false, "heartlib_metal_ops: failed to load metallib at ", lib_path, ": ", msg); + } +} + +static id ensure_pipeline( + id device, + id __strong* pipeline, + const char* fn_name) { + std::lock_guard lock(g_pipeline_mutex); + ensure_library_locked(device); + + if (*pipeline != nil) { + return *pipeline; + } + + NSString* ns_fn = [NSString stringWithUTF8String:fn_name]; + id fn = [g_lib newFunctionWithName:ns_fn]; + TORCH_CHECK(fn != nil, "heartlib_metal_ops: function `", fn_name, "` not found in metallib"); + + NSError* err = nil; + *pipeline = [device newComputePipelineStateWithFunction:fn error:&err]; + if (*pipeline == nil) { + const char* msg = err ? 
[[err localizedDescription] UTF8String] : "unknown error"; + TORCH_CHECK(false, "heartlib_metal_ops: failed to create compute pipeline: ", msg); + } + + TORCH_CHECK( + (*pipeline).maxTotalThreadsPerThreadgroup >= kThreadsPerThreadgroup, + "heartlib_metal_ops: pipeline maxTotalThreadsPerThreadgroup (", + (int)(*pipeline).maxTotalThreadsPerThreadgroup, + ") < expected threads (", + (int)kThreadsPerThreadgroup, + ")"); + return *pipeline; +} + +static inline id storage_as_mtlbuffer(const at::Tensor& t) { + const auto& dp = t.storage().data_ptr(); + void* ctx = dp.get_context(); + TORCH_CHECK( + ctx != nullptr, + "heartlib_metal_ops: expected MPS storage to provide an MTLBuffer context (got null)."); + return (__bridge id)ctx; +} + +static inline NSUInteger storage_offset_bytes(const at::Tensor& t) { + return (NSUInteger)(t.storage_offset() * (int64_t)t.element_size()); +} + +torch::Tensor rmsnorm( + at::Tensor x, + at::Tensor weight, + double eps) { + TORCH_CHECK(x.device().is_mps(), "rmsnorm: x must be on MPS"); + TORCH_CHECK(weight.device().is_mps(), "rmsnorm: weight must be on MPS"); + TORCH_CHECK(x.dtype() == at::kHalf || x.dtype() == at::kFloat, "rmsnorm: x must be fp16 or fp32"); + TORCH_CHECK(weight.dtype() == x.dtype(), "rmsnorm: weight dtype must match x"); + TORCH_CHECK(x.is_contiguous(), "rmsnorm: x must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "rmsnorm: weight must be contiguous"); + TORCH_CHECK(x.dim() >= 1, "rmsnorm: x must have dim >= 1"); + + const int64_t D = x.size(-1); + TORCH_CHECK(D > 0, "rmsnorm: invalid last dim"); + TORCH_CHECK(weight.numel() == D, "rmsnorm: weight must have numel == x.size(-1)"); + + auto out = torch::empty_like(x); + const int64_t rows = x.numel() / D; + TORCH_CHECK(rows * D == x.numel(), "rmsnorm: x.numel must be divisible by D"); + + id device = (id)at::mps::MPSDevice::getInstance()->device(); + id pipeline = (x.dtype() == at::kFloat) + ? 
ensure_pipeline(device, &g_pipeline_rmsnorm_fp32, "rmsnorm_fp32") + : ensure_pipeline(device, &g_pipeline_rmsnorm, "rmsnorm_fp16"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "rmsnorm: failed to get current MPS stream"); + id encoder = (id)stream->commandEncoder(); + TORCH_CHECK(encoder != nil, "rmsnorm: failed to get MTLComputeCommandEncoder from MPS stream"); + + [encoder setComputePipelineState:pipeline]; + + auto set_tensor = [&](const at::Tensor& t, int idx) { + id buf = storage_as_mtlbuffer(t); + TORCH_CHECK(buf != nil, "rmsnorm: tensor has null MTLBuffer"); + const NSUInteger off = storage_offset_bytes(t); + [encoder setBuffer:buf offset:off atIndex:(NSUInteger)idx]; + }; + + set_tensor(x, 0); + set_tensor(weight, 1); + set_tensor(out, 2); + + RMSNormParams params; + params.d_model = (uint32_t)D; + params.eps = (float)eps; + params.stride_row = (uint32_t)D; + [encoder setBytes:¶ms length:sizeof(RMSNormParams) atIndex:3]; + + const MTLSize tg = MTLSizeMake(kThreadsPerThreadgroup, 1, 1); + const MTLSize grid = MTLSizeMake((NSUInteger)rows, 1, 1); + [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + return out; +} + +torch::Tensor rmsnorm_noweight( + at::Tensor x, + double eps) { + TORCH_CHECK(x.device().is_mps(), "rmsnorm_noweight: x must be on MPS"); + TORCH_CHECK(x.dtype() == at::kHalf || x.dtype() == at::kFloat, "rmsnorm_noweight: x must be fp16 or fp32"); + TORCH_CHECK(x.is_contiguous(), "rmsnorm_noweight: x must be contiguous"); + TORCH_CHECK(x.dim() >= 1, "rmsnorm_noweight: x must have dim >= 1"); + + const int64_t D = x.size(-1); + TORCH_CHECK(D > 0, "rmsnorm_noweight: invalid last dim"); + + auto out = torch::empty_like(x); + const int64_t rows = x.numel() / D; + TORCH_CHECK(rows * D == x.numel(), "rmsnorm_noweight: x.numel must be divisible by D"); + + id device = (id)at::mps::MPSDevice::getInstance()->device(); + id pipeline = (x.dtype() == at::kFloat) + ? 
ensure_pipeline(device, &g_pipeline_rmsnorm_noweight_fp32, "rmsnorm_noweight_fp32") + : ensure_pipeline(device, &g_pipeline_rmsnorm_noweight, "rmsnorm_noweight_fp16"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "rmsnorm_noweight: failed to get current MPS stream"); + id encoder = (id)stream->commandEncoder(); + TORCH_CHECK(encoder != nil, "rmsnorm_noweight: failed to get MTLComputeCommandEncoder from MPS stream"); + [encoder setComputePipelineState:pipeline]; + + auto set_tensor = [&](const at::Tensor& t, int idx) { + id buf = storage_as_mtlbuffer(t); + TORCH_CHECK(buf != nil, "rmsnorm_noweight: tensor has null MTLBuffer"); + const NSUInteger off = storage_offset_bytes(t); + [encoder setBuffer:buf offset:off atIndex:(NSUInteger)idx]; + }; + + set_tensor(x, 0); + set_tensor(out, 1); + + RMSNormParams params; + params.d_model = (uint32_t)D; + params.eps = (float)eps; + params.stride_row = (uint32_t)D; + [encoder setBytes:¶ms length:sizeof(RMSNormParams) atIndex:2]; + + const MTLSize tg = MTLSizeMake(kThreadsPerThreadgroup, 1, 1); + const MTLSize grid = MTLSizeMake((NSUInteger)rows, 1, 1); + [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + return out; +} + +std::vector rmsnorm_forward_with_inv( + at::Tensor x, + at::Tensor weight, + double eps) { + TORCH_CHECK(x.device().is_mps(), "rmsnorm_forward_with_inv: x must be on MPS"); + TORCH_CHECK(weight.device().is_mps(), "rmsnorm_forward_with_inv: weight must be on MPS"); + TORCH_CHECK(x.dtype() == at::kHalf || x.dtype() == at::kFloat, "rmsnorm_forward_with_inv: x must be fp16 or fp32"); + TORCH_CHECK(weight.dtype() == x.dtype(), "rmsnorm_forward_with_inv: weight dtype must match x"); + TORCH_CHECK(x.is_contiguous(), "rmsnorm_forward_with_inv: x must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "rmsnorm_forward_with_inv: weight must be contiguous"); + TORCH_CHECK(x.dim() >= 1, "rmsnorm_forward_with_inv: x must have dim >= 1"); + + const int64_t D = x.size(-1); + TORCH_CHECK(D > 0, "rmsnorm_forward_with_inv: invalid last dim"); + TORCH_CHECK(weight.numel() == D, "rmsnorm_forward_with_inv: weight must have numel == x.size(-1)"); + + auto out = torch::empty_like(x); + const int64_t rows = x.numel() / D; + TORCH_CHECK(rows * D == x.numel(), "rmsnorm_forward_with_inv: x.numel must be divisible by D"); + auto inv = torch::empty({rows}, x.options()); + + id device = (id)at::mps::MPSDevice::getInstance()->device(); + id pipeline = (x.dtype() == at::kFloat) + ? 
ensure_pipeline(device, &g_pipeline_rmsnorm_fwd_inv_fp32, "rmsnorm_fwd_inv_fp32") + : ensure_pipeline(device, &g_pipeline_rmsnorm_fwd_inv, "rmsnorm_fwd_inv_fp16"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "rmsnorm_forward_with_inv: failed to get current MPS stream"); + id encoder = (id)stream->commandEncoder(); + TORCH_CHECK(encoder != nil, "rmsnorm_forward_with_inv: failed to get MTLComputeCommandEncoder from MPS stream"); + [encoder setComputePipelineState:pipeline]; + + auto set_tensor = [&](const at::Tensor& t, int idx) { + id buf = storage_as_mtlbuffer(t); + TORCH_CHECK(buf != nil, "rmsnorm_forward_with_inv: tensor has null MTLBuffer"); + const NSUInteger off = storage_offset_bytes(t); + [encoder setBuffer:buf offset:off atIndex:(NSUInteger)idx]; + }; + + set_tensor(x, 0); + set_tensor(weight, 1); + set_tensor(out, 2); + set_tensor(inv, 3); + + RMSNormParams params; + params.d_model = (uint32_t)D; + params.eps = (float)eps; + params.stride_row = (uint32_t)D; + [encoder setBytes:¶ms length:sizeof(RMSNormParams) atIndex:4]; + + const MTLSize tg = MTLSizeMake(kThreadsPerThreadgroup, 1, 1); + const MTLSize grid = MTLSizeMake((NSUInteger)rows, 1, 1); + [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + return {out, inv}; +} + +std::vector rmsnorm_noweight_forward_with_inv( + at::Tensor x, + double eps) { + TORCH_CHECK(x.device().is_mps(), "rmsnorm_noweight_forward_with_inv: x must be on MPS"); + TORCH_CHECK(x.dtype() == at::kHalf || x.dtype() == at::kFloat, "rmsnorm_noweight_forward_with_inv: x must be fp16 or fp32"); + TORCH_CHECK(x.is_contiguous(), "rmsnorm_noweight_forward_with_inv: x must be contiguous"); + TORCH_CHECK(x.dim() >= 1, "rmsnorm_noweight_forward_with_inv: x must have dim >= 1"); + + const int64_t D = x.size(-1); + TORCH_CHECK(D > 0, "rmsnorm_noweight_forward_with_inv: invalid last dim"); + + auto out = torch::empty_like(x); + const int64_t rows = x.numel() / D; + TORCH_CHECK(rows * D == x.numel(), "rmsnorm_noweight_forward_with_inv: x.numel must be divisible by D"); + auto inv = torch::empty({rows}, x.options()); + + id device = (id)at::mps::MPSDevice::getInstance()->device(); + id pipeline = (x.dtype() == at::kFloat) + ? 
ensure_pipeline(device, &g_pipeline_rmsnorm_noweight_fwd_inv_fp32, "rmsnorm_noweight_fwd_inv_fp32") + : ensure_pipeline(device, &g_pipeline_rmsnorm_noweight_fwd_inv, "rmsnorm_noweight_fwd_inv_fp16"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "rmsnorm_noweight_forward_with_inv: failed to get current MPS stream"); + id encoder = (id)stream->commandEncoder(); + TORCH_CHECK(encoder != nil, "rmsnorm_noweight_forward_with_inv: failed to get MTLComputeCommandEncoder from MPS stream"); + [encoder setComputePipelineState:pipeline]; + + auto set_tensor = [&](const at::Tensor& t, int idx) { + id buf = storage_as_mtlbuffer(t); + TORCH_CHECK(buf != nil, "rmsnorm_noweight_forward_with_inv: tensor has null MTLBuffer"); + const NSUInteger off = storage_offset_bytes(t); + [encoder setBuffer:buf offset:off atIndex:(NSUInteger)idx]; + }; + + set_tensor(x, 0); + set_tensor(out, 1); + set_tensor(inv, 2); + + RMSNormParams params; + params.d_model = (uint32_t)D; + params.eps = (float)eps; + params.stride_row = (uint32_t)D; + [encoder setBytes:¶ms length:sizeof(RMSNormParams) atIndex:3]; + + const MTLSize tg = MTLSizeMake(kThreadsPerThreadgroup, 1, 1); + const MTLSize grid = MTLSizeMake((NSUInteger)rows, 1, 1); + [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + return {out, inv}; +} + +torch::Tensor rope( + at::Tensor x, + at::Tensor cos, + at::Tensor sin, + int64_t rot_dim) { + TORCH_CHECK(x.device().is_mps(), "rope: x must be on MPS"); + TORCH_CHECK(cos.device().is_mps(), "rope: cos must be on MPS"); + TORCH_CHECK(sin.device().is_mps(), "rope: sin must be on MPS"); + TORCH_CHECK(x.dtype() == at::kHalf || x.dtype() == at::kFloat, "rope: x must be fp16 or fp32"); + TORCH_CHECK(cos.dtype() == x.dtype(), "rope: cos dtype must match x"); + TORCH_CHECK(sin.dtype() == x.dtype(), "rope: sin dtype must match x"); + TORCH_CHECK(x.is_contiguous(), "rope: x must be contiguous"); + TORCH_CHECK(cos.is_contiguous(), "rope: cos must be contiguous"); + TORCH_CHECK(sin.is_contiguous(), "rope: sin must be contiguous"); + TORCH_CHECK(x.dim() == 4, "rope: x must be (B,H,T,D)"); + TORCH_CHECK(cos.dim() == 2, "rope: cos must be (T, rot/2)"); + TORCH_CHECK(sin.dim() == 2, "rope: sin must be (T, rot/2)"); + TORCH_CHECK(rot_dim > 0, "rope: rot_dim must be > 0"); + TORCH_CHECK((rot_dim % 2) == 0, "rope: rot_dim must be even"); + + const int64_t B = x.size(0); + const int64_t H = x.size(1); + const int64_t T = x.size(2); + const int64_t D = x.size(3); + TORCH_CHECK(rot_dim <= D, "rope: rot_dim must be <= head_dim"); + + const int64_t half_rot = rot_dim / 2; + TORCH_CHECK(cos.size(0) == T && cos.size(1) == half_rot, "rope: cos shape mismatch"); + TORCH_CHECK(sin.size(0) == T && sin.size(1) == half_rot, "rope: sin shape mismatch"); + + auto out = torch::empty_like(x); + const int64_t n_vec = B * H * T; + + id device = (id)at::mps::MPSDevice::getInstance()->device(); + id pipeline = (x.dtype() == at::kFloat) + ? 
ensure_pipeline(device, &g_pipeline_rope_fp32, "rope_fp32") + : ensure_pipeline(device, &g_pipeline_rope, "rope_fp16"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "rope: failed to get current MPS stream"); + id encoder = (id)stream->commandEncoder(); + TORCH_CHECK(encoder != nil, "rope: failed to get MTLComputeCommandEncoder from MPS stream"); + [encoder setComputePipelineState:pipeline]; + + auto set_tensor = [&](const at::Tensor& t, int idx) { + id buf = storage_as_mtlbuffer(t); + TORCH_CHECK(buf != nil, "rope: tensor has null MTLBuffer"); + const NSUInteger off = storage_offset_bytes(t); + [encoder setBuffer:buf offset:off atIndex:(NSUInteger)idx]; + }; + + set_tensor(x, 0); + set_tensor(cos, 1); + set_tensor(sin, 2); + set_tensor(out, 3); + + RoPEParams params; + params.d_model = (uint32_t)D; + params.rot_dim = (uint32_t)rot_dim; + params.half_rot = (uint32_t)half_rot; + params.seq_len = (uint32_t)T; + [encoder setBytes:¶ms length:sizeof(RoPEParams) atIndex:4]; + + const MTLSize tg = MTLSizeMake(kThreadsPerThreadgroup, 1, 1); + const MTLSize grid = MTLSizeMake((NSUInteger)n_vec, 1, 1); + [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + return out; +} + +torch::Tensor rope_backward( + at::Tensor grad_y, + at::Tensor cos, + at::Tensor sin, + int64_t rot_dim) { + TORCH_CHECK(grad_y.device().is_mps(), "rope_backward: grad_y must be on MPS"); + TORCH_CHECK(cos.device().is_mps(), "rope_backward: cos must be on MPS"); + TORCH_CHECK(sin.device().is_mps(), "rope_backward: sin must be on MPS"); + TORCH_CHECK(grad_y.dtype() == at::kHalf || grad_y.dtype() == at::kFloat, "rope_backward: grad_y must be fp16 or fp32"); + TORCH_CHECK(cos.dtype() == grad_y.dtype(), "rope_backward: cos dtype must match grad_y"); + TORCH_CHECK(sin.dtype() == grad_y.dtype(), "rope_backward: sin dtype must match grad_y"); + TORCH_CHECK(grad_y.is_contiguous(), "rope_backward: grad_y must be contiguous"); + TORCH_CHECK(cos.is_contiguous(), "rope_backward: cos must be contiguous"); + TORCH_CHECK(sin.is_contiguous(), "rope_backward: sin must be contiguous"); + TORCH_CHECK(grad_y.dim() == 4, "rope_backward: grad_y must be (B,H,T,D)"); + TORCH_CHECK(rot_dim > 0, "rope_backward: rot_dim must be > 0"); + TORCH_CHECK((rot_dim % 2) == 0, "rope_backward: rot_dim must be even"); + + const int64_t B = grad_y.size(0); + const int64_t H = grad_y.size(1); + const int64_t T = grad_y.size(2); + const int64_t D = grad_y.size(3); + TORCH_CHECK(rot_dim <= D, "rope_backward: rot_dim must be <= head_dim"); + + const int64_t half_rot = rot_dim / 2; + TORCH_CHECK(cos.size(0) == T && cos.size(1) == half_rot, "rope_backward: cos shape mismatch"); + TORCH_CHECK(sin.size(0) == T && sin.size(1) == half_rot, "rope_backward: sin shape mismatch"); + + auto grad_x = torch::empty_like(grad_y); + const int64_t n_vec = B * H * T; + + id device = (id)at::mps::MPSDevice::getInstance()->device(); + id pipeline = (grad_y.dtype() == at::kFloat) + ? 
ensure_pipeline(device, &g_pipeline_rope_bwd_fp32, "rope_bwd_fp32") + : ensure_pipeline(device, &g_pipeline_rope_bwd, "rope_bwd_fp16"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "rope_backward: failed to get current MPS stream"); + id encoder = (id)stream->commandEncoder(); + TORCH_CHECK(encoder != nil, "rope_backward: failed to get MTLComputeCommandEncoder from MPS stream"); + [encoder setComputePipelineState:pipeline]; + + auto set_tensor = [&](const at::Tensor& t, int idx) { + id buf = storage_as_mtlbuffer(t); + TORCH_CHECK(buf != nil, "rope_backward: tensor has null MTLBuffer"); + const NSUInteger off = storage_offset_bytes(t); + [encoder setBuffer:buf offset:off atIndex:(NSUInteger)idx]; + }; + + set_tensor(grad_y, 0); + set_tensor(cos, 1); + set_tensor(sin, 2); + set_tensor(grad_x, 3); + + RoPEParams params; + params.d_model = (uint32_t)D; + params.rot_dim = (uint32_t)rot_dim; + params.half_rot = (uint32_t)half_rot; + params.seq_len = (uint32_t)T; + [encoder setBytes:¶ms length:sizeof(RoPEParams) atIndex:4]; + + const MTLSize tg = MTLSizeMake(kThreadsPerThreadgroup, 1, 1); + const MTLSize grid = MTLSizeMake((NSUInteger)n_vec, 1, 1); + [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + return grad_x; +} + +} // namespace + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rmsnorm", &rmsnorm, "RMSNorm Forward (Metal/MPS)"); + m.def("rmsnorm_noweight", &rmsnorm_noweight, "RMSNorm Forward (no weight, Metal/MPS)"); + m.def("rmsnorm_forward_with_inv", &rmsnorm_forward_with_inv, "RMSNorm Forward with inv cache (Metal/MPS)"); + m.def("rmsnorm_noweight_forward_with_inv", &rmsnorm_noweight_forward_with_inv, "RMSNorm Forward (no weight) with inv cache (Metal/MPS)"); + m.def("rope", &rope, "RoPE Apply (Metal/MPS)"); + m.def("rope_backward", &rope_backward, "RoPE Backward (Metal/MPS)"); +} + diff --git a/src/heartlib/accelerators/metal/rmsnorm.metal b/src/heartlib/accelerators/metal/rmsnorm.metal new file mode 100644 index 0000000..478f402 --- /dev/null +++ b/src/heartlib/accelerators/metal/rmsnorm.metal @@ -0,0 +1,179 @@ +#include +using namespace metal; + +// Must match `RMSNormParams` in `ops.mm` (layout + types). +struct RMSNormParams { + uint d_model; + float eps; + uint stride_row; // in elements +}; + +// Must match `RMSNormGradWParams` in `ops.mm`. 
+struct RMSNormGradWParams { + uint d_model; + uint rows; + uint stride_row; // in elements +}; + +constant uint TG = 256; +constant uint SIMD = 32; +constant uint NSIMD = TG / SIMD; // 8 + +template +inline void rmsnorm_fwd_impl( + device const T* x, + device const T* weight, + device T* out, + device T* inv_out, + constant RMSNormParams& p, + uint tid, + uint tg_id, + threadgroup float* tg_sum, + threadgroup float* shared_inv +) { + const uint row = tg_id; + device const T* xr = x + row * p.stride_row; + device T* yr = out + row * p.stride_row; + + float sum = 0.0f; + for (uint i = tid; i < p.d_model; i += TG) { + const float v = float(xr[i]); + sum += v * v; + } + + const float sg_sum = simd_sum(sum); + const bool lane0 = (tid % SIMD) == 0; + if (lane0) { + tg_sum[tid / SIMD] = sg_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid == 0) { + float total = 0.0f; + for (uint i = 0; i < NSIMD; ++i) { + total += tg_sum[i]; + } + const float mean = total / float(p.d_model); + const float inv = rsqrt(mean + p.eps); + *shared_inv = inv; + if constexpr (HAS_BWD_INV) { + inv_out[row] = T(inv); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float inv = *shared_inv; + for (uint i = tid; i < p.d_model; i += TG) { + const float v = float(xr[i]); + float val = v * inv; + if constexpr (HAS_WEIGHT) { + val *= float(weight[i]); + } + yr[i] = T(val); + } +} + +kernel void rmsnorm_fp16( + device const half* x [[ buffer(0) ]], + device const half* weight [[ buffer(1) ]], + device half* out [[ buffer(2) ]], + constant RMSNormParams& p [[ buffer(3) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, weight, out, nullptr, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_fwd_inv_fp16( + device const half* x [[ buffer(0) ]], + device const half* weight [[ buffer(1) ]], + device half* out [[ buffer(2) ]], + device half* inv_out [[ buffer(3) ]], + constant RMSNormParams& p [[ buffer(4) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, weight, out, inv_out, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_noweight_fp16( + device const half* x [[ buffer(0) ]], + device half* out [[ buffer(1) ]], + constant RMSNormParams& p [[ buffer(2) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, nullptr, out, nullptr, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_noweight_fwd_inv_fp16( + device const half* x [[ buffer(0) ]], + device half* out [[ buffer(1) ]], + device half* inv_out [[ buffer(2) ]], + constant RMSNormParams& p [[ buffer(3) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, nullptr, out, inv_out, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_fp32( + device const float* x [[ buffer(0) ]], + device const float* weight [[ buffer(1) ]], + device float* out [[ buffer(2) ]], + constant RMSNormParams& p [[ buffer(3) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup 
float shared_inv; + rmsnorm_fwd_impl(x, weight, out, nullptr, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_fwd_inv_fp32( + device const float* x [[ buffer(0) ]], + device const float* weight [[ buffer(1) ]], + device float* out [[ buffer(2) ]], + device float* inv_out [[ buffer(3) ]], + constant RMSNormParams& p [[ buffer(4) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, weight, out, inv_out, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_noweight_fp32( + device const float* x [[ buffer(0) ]], + device float* out [[ buffer(1) ]], + constant RMSNormParams& p [[ buffer(2) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, nullptr, out, nullptr, p, tid, tg_id, tg_sum, &shared_inv); +} + +kernel void rmsnorm_noweight_fwd_inv_fp32( + device const float* x [[ buffer(0) ]], + device float* out [[ buffer(1) ]], + device float* inv_out [[ buffer(2) ]], + constant RMSNormParams& p [[ buffer(3) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + threadgroup float tg_sum[8]; + threadgroup float shared_inv; + rmsnorm_fwd_impl(x, nullptr, out, inv_out, p, tid, tg_id, tg_sum, &shared_inv); +} + diff --git a/src/heartlib/accelerators/metal/rmsnorm.py b/src/heartlib/accelerators/metal/rmsnorm.py new file mode 100644 index 0000000..b987088 --- /dev/null +++ b/src/heartlib/accelerators/metal/rmsnorm.py @@ -0,0 +1,110 @@ +"""Fused RMSNorm wrapper for the Metal extension.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +import torch + +from .runtime import metal_supported +from .jit import load_heartlib_metal_ops + +if TYPE_CHECKING: + from torch import Tensor + + +class _AutogradCtx(Protocol): + saved_tensors: tuple["Tensor", ...] + + def save_for_backward(self, *tensors: "Tensor") -> None: ... 
+ + +def metal_rmsnorm_available() -> bool: + return metal_supported() + + +class _MetalRMSNormFn(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx: _AutogradCtx, + x: "Tensor", + weight: "Tensor | None", + eps: float, + verbose_build: bool, + ) -> "Tensor": + if x.device.type != "mps": + raise RuntimeError("Metal RMSNorm requires device.type == 'mps'") + if x.dtype not in (torch.float16, torch.float32): + raise RuntimeError("Metal RMSNorm supports fp16/fp32 only") + + x2 = x.contiguous() + ops = load_heartlib_metal_ops(verbose=bool(verbose_build)) + + if weight is None: + out, inv = ops.rmsnorm_noweight_forward_with_inv(x2, float(eps)) + ctx.save_for_backward(x2, inv) + return out + + w2 = weight.to(device=x.device, dtype=x.dtype).contiguous() + out, inv = ops.rmsnorm_forward_with_inv(x2, w2, float(eps)) + ctx.save_for_backward(x2, w2, inv) + return out + + @staticmethod + def backward( # type: ignore[override] + ctx: _AutogradCtx, + grad_out: "Tensor", + ) -> tuple["Tensor | None", ...]: + if grad_out is None: + raise RuntimeError("Metal RMSNorm backward requires grad_out") + if grad_out.device.type != "mps": + raise RuntimeError("Metal RMSNorm backward requires grad_out on MPS") + + saved = ctx.saved_tensors + target_dtype = saved[0].dtype + if grad_out.dtype != target_dtype: + grad_out = grad_out.to(dtype=target_dtype) + g = grad_out.contiguous() + + ops = load_heartlib_metal_ops(verbose=False) + if len(saved) == 2: + x, inv = saved + grad_x = ops.rmsnorm_backward_x_noweight(g, x, inv) + return (grad_x, None, None, None) + if len(saved) == 3: + x, w, inv = saved + grad_x = ops.rmsnorm_backward_x(g, x, w, inv) + grad_w = ops.rmsnorm_backward_w(g, x, inv) + return (grad_x, grad_w, None, None) + raise RuntimeError("Metal RMSNorm backward: invalid saved tensor state") + + +def rmsnorm_fp16( + *, + x: "Tensor", + weight: "Tensor | None", + eps: float = 1e-6, + verbose_build: bool = False, +) -> "Tensor": + """Fused RMSNorm (MPS/Metal) for fp16/fp32 tensors.""" + if x.device.type != "mps": + raise RuntimeError("Metal RMSNorm requires device.type == 'mps'") + if x.dtype not in (torch.float16, torch.float32): + raise RuntimeError("Metal RMSNorm supports fp16/fp32 only") + + needs_grad = bool(x.requires_grad) or ( + weight is not None and bool(weight.requires_grad) + ) + if not needs_grad: + x2 = x.contiguous() + ops = load_heartlib_metal_ops(verbose=bool(verbose_build)) + if weight is None: + return ops.rmsnorm_noweight(x2, float(eps)) + w2 = weight.to(device=x.device, dtype=x.dtype).contiguous() + return ops.rmsnorm(x2, w2, float(eps)) + + y = _MetalRMSNormFn.apply(x, weight, float(eps), bool(verbose_build)) + if not isinstance(y, torch.Tensor): + raise TypeError("Metal RMSNorm returned a non-tensor output") + return y + diff --git a/src/heartlib/accelerators/metal/rope.metal b/src/heartlib/accelerators/metal/rope.metal new file mode 100644 index 0000000..02cd585 --- /dev/null +++ b/src/heartlib/accelerators/metal/rope.metal @@ -0,0 +1,127 @@ +#include +using namespace metal; + +// Must match `RoPEParams` in `ops.mm` (layout + types). +struct RoPEParams { + uint d_model; + uint rot_dim; + uint half_rot; + uint seq_len; +}; + +template +inline void rope_impl( + device const T* x, + device const T* cos_t, + device const T* sin_t, + device T* out, + constant RoPEParams& p, + uint tid, + uint tg_id +) { + constexpr uint TG = 256; + const uint vec = tg_id; + const uint t = (p.seq_len > 0) ? 
(vec % p.seq_len) : 0; + + device const T* xr = x + vec * p.d_model; + device T* yr = out + vec * p.d_model; + + device const T* c = cos_t + t * p.half_rot; + device const T* s = sin_t + t * p.half_rot; + + for (uint i = tid; i < p.d_model; i += TG) { + if (i < p.half_rot) { + const float x1 = float(xr[i]); + const float x2 = float(xr[i + p.half_rot]); + const float cc = float(c[i]); + const float ss = float(s[i]); + yr[i] = T(x1 * cc - x2 * ss); + yr[i + p.half_rot] = T(x1 * ss + x2 * cc); + } else if (i >= p.rot_dim) { + yr[i] = xr[i]; + } + } +} + +template +inline void rope_bwd_impl( + device const T* grad_y, + device const T* cos_t, + device const T* sin_t, + device T* grad_x, + constant RoPEParams& p, + uint tid, + uint tg_id +) { + constexpr uint TG = 256; + const uint vec = tg_id; + const uint t = (p.seq_len > 0) ? (vec % p.seq_len) : 0; + + device const T* gr = grad_y + vec * p.d_model; + device T* gx = grad_x + vec * p.d_model; + + device const T* c = cos_t + t * p.half_rot; + device const T* s = sin_t + t * p.half_rot; + + for (uint i = tid; i < p.d_model; i += TG) { + if (i < p.half_rot) { + const float gy1 = float(gr[i]); + const float gy2 = float(gr[i + p.half_rot]); + const float cc = float(c[i]); + const float ss = float(s[i]); + gx[i] = T(gy1 * cc + gy2 * ss); + gx[i + p.half_rot] = T(-gy1 * ss + gy2 * cc); + } else if (i >= p.rot_dim) { + gx[i] = gr[i]; + } + } +} + +kernel void rope_fp16( + device const half* x [[ buffer(0) ]], + device const half* cos_t [[ buffer(1) ]], + device const half* sin_t [[ buffer(2) ]], + device half* out [[ buffer(3) ]], + constant RoPEParams& p [[ buffer(4) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + rope_impl(x, cos_t, sin_t, out, p, tid, tg_id); +} + +kernel void rope_fp32( + device const float* x [[ buffer(0) ]], + device const float* cos_t [[ buffer(1) ]], + device const float* sin_t [[ buffer(2) ]], + device float* out [[ buffer(3) ]], + constant RoPEParams& p [[ buffer(4) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + rope_impl(x, cos_t, sin_t, out, p, tid, tg_id); +} + +kernel void rope_bwd_fp16( + device const half* grad_y [[ buffer(0) ]], + device const half* cos_t [[ buffer(1) ]], + device const half* sin_t [[ buffer(2) ]], + device half* grad_x [[ buffer(3) ]], + constant RoPEParams& p [[ buffer(4) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + rope_bwd_impl(grad_y, cos_t, sin_t, grad_x, p, tid, tg_id); +} + +kernel void rope_bwd_fp32( + device const float* grad_y [[ buffer(0) ]], + device const float* cos_t [[ buffer(1) ]], + device const float* sin_t [[ buffer(2) ]], + device float* grad_x [[ buffer(3) ]], + constant RoPEParams& p [[ buffer(4) ]], + uint tid [[ thread_position_in_threadgroup ]], + uint tg_id [[ threadgroup_position_in_grid ]] +) { + rope_bwd_impl(grad_y, cos_t, sin_t, grad_x, p, tid, tg_id); +} + diff --git a/src/heartlib/accelerators/metal/rope.py b/src/heartlib/accelerators/metal/rope.py new file mode 100644 index 0000000..a9c343c --- /dev/null +++ b/src/heartlib/accelerators/metal/rope.py @@ -0,0 +1,102 @@ +"""Fused RoPE wrapper for the Metal extension.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +import torch + +from .runtime import metal_supported +from .jit import load_heartlib_metal_ops + +if TYPE_CHECKING: + from torch import Tensor + + +class _AutogradCtx(Protocol): + 
saved_tensors: tuple["Tensor", ...] + + def save_for_backward(self, *tensors: "Tensor") -> None: ... + + +def metal_rope_available() -> bool: + return metal_supported() + + +class _MetalRoPEFn(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx: _AutogradCtx, + x: "Tensor", + cos: "Tensor", + sin: "Tensor", + rot_dim: int, + verbose_build: bool, + ) -> "Tensor": + if x.device.type != "mps": + raise RuntimeError("Metal RoPE requires device.type == 'mps'") + if x.dtype not in (torch.float16, torch.float32): + raise RuntimeError("Metal RoPE supports fp16/fp32 only") + + x2 = x.contiguous() + cos2 = cos.to(device=x.device, dtype=x.dtype).contiguous() + sin2 = sin.to(device=x.device, dtype=x.dtype).contiguous() + + ops = load_heartlib_metal_ops(verbose=bool(verbose_build)) + ctx.save_for_backward(cos2, sin2) + ctx.rot_dim = int(rot_dim) # type: ignore[attr-defined] + return ops.rope(x2, cos2, sin2, int(rot_dim)) + + @staticmethod + def backward( # type: ignore[override] + ctx: _AutogradCtx, + grad_out: "Tensor", + ) -> tuple["Tensor | None", ...]: + if grad_out is None: + raise RuntimeError("Metal RoPE backward requires grad_out") + if grad_out.device.type != "mps": + raise RuntimeError("Metal RoPE backward requires grad_out on MPS") + + (cos, sin) = ctx.saved_tensors + target_dtype = cos.dtype + if grad_out.dtype != target_dtype: + grad_out = grad_out.to(dtype=target_dtype) + g = grad_out.contiguous() + + rot_dim = int(getattr(ctx, "rot_dim")) + ops = load_heartlib_metal_ops(verbose=False) + grad_x = ops.rope_backward(g, cos, sin, rot_dim) + return (grad_x, None, None, None, None) + + +def rope_fp16( + *, + x: Tensor, + cos: Tensor, + sin: Tensor, + rot_dim: int, + verbose_build: bool = False, +) -> Tensor: + """Apply RoPE using the Metal extension (fp16/fp32). + + Kernel expects: + - x: (B, H, T, D) + - cos/sin: (T, rot_dim/2) + """ + if x.device.type != "mps": + raise RuntimeError("Metal RoPE requires device.type == 'mps'") + if x.dtype not in (torch.float16, torch.float32): + raise RuntimeError("Metal RoPE supports fp16/fp32 only") + + if not bool(x.requires_grad): + x2 = x.contiguous() + cos2 = cos.to(device=x.device, dtype=x.dtype).contiguous() + sin2 = sin.to(device=x.device, dtype=x.dtype).contiguous() + ops = load_heartlib_metal_ops(verbose=bool(verbose_build)) + return ops.rope(x2, cos2, sin2, int(rot_dim)) + + y = _MetalRoPEFn.apply(x, cos, sin, int(rot_dim), bool(verbose_build)) + if not isinstance(y, torch.Tensor): + raise TypeError("Metal RoPE returned a non-tensor output") + return y + diff --git a/src/heartlib/accelerators/metal/runtime.py b/src/heartlib/accelerators/metal/runtime.py new file mode 100644 index 0000000..ebfb609 --- /dev/null +++ b/src/heartlib/accelerators/metal/runtime.py @@ -0,0 +1,51 @@ +"""Backend availability detection for Metal/MPS. + +This is a tiny, dependency-light helper used by the optional Metal fast path. 
+""" + +from __future__ import annotations + +import platform +import shutil +import subprocess +from typing import TYPE_CHECKING + +import torch + +__all__ = [ + "metal_supported", + "metal_build_tools_available", +] + + +def metal_supported() -> bool: + """Whether the current runtime *can* execute custom Metal (MPS) ops.""" + if TYPE_CHECKING: + return False + if platform.system() != "Darwin": + return False + try: + return bool(torch.backends.mps.is_available()) + except Exception: + return False + + +def metal_build_tools_available() -> bool: + """Whether the host can compile Metal shaders via Xcode toolchain.""" + if TYPE_CHECKING: + return False + if not metal_supported(): + return False + if shutil.which("xcrun") is None: + return False + try: + subprocess.check_output( + ["xcrun", "-sdk", "macosx", "--find", "metal"], stderr=subprocess.STDOUT + ) + subprocess.check_output( + ["xcrun", "-sdk", "macosx", "--find", "metallib"], stderr=subprocess.STDOUT + ) + except Exception: + return False + return True + diff --git a/src/heartlib/accelerators/torchtune_metal.py b/src/heartlib/accelerators/torchtune_metal.py index bf26b72..9c0e3fd 100644 --- a/src/heartlib/accelerators/torchtune_metal.py +++ b/src/heartlib/accelerators/torchtune_metal.py @@ -28,7 +28,7 @@ def __init__(self, *, scale: nn.Parameter, eps: float): # Lazy import: optimizer is an optional local package. self._metal_impl = None try: - from optimizer.metal.rmsnorm import rmsnorm_fp16 as _metal_rmsnorm + from .metal.rmsnorm import rmsnorm_fp16 as _metal_rmsnorm self._metal_impl = _metal_rmsnorm except Exception: @@ -69,7 +69,7 @@ def __init__(self, inner: nn.Module): self._metal_impl = None try: - from optimizer.metal.rope import rope_fp16 as _metal_rope + from .metal.rope import rope_fp16 as _metal_rope self._metal_impl = _metal_rope except Exception: From 76796e34565f260d013bd5ec9dd944dba1b42484 Mon Sep 17 00:00:00 2001 From: Daniel Owen van Dommelen Date: Mon, 19 Jan 2026 22:44:33 +0100 Subject: [PATCH 6/6] Update `pyproject.toml` to restrict Python version compatibility and add optional dependencies for MuQ-MuLan. Modify `README.md` to reflect new Python version recommendations and installation instructions for optional features. Enhance `run_music_generation.py` and `HeartMuLaGenPipeline` to support reference audio conditioning and auto-download of MuQ-MuLan, improving music generation capabilities. --- README.md | 14 +- examples/run_music_generation.py | 57 ++++++ pyproject.toml | 8 +- src/heartlib/pipelines/music_generation.py | 207 ++++++++++++++++++++- 4 files changed, 279 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 314e205..1dfdfd8 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Our latest internal version of HeartMuLa-7B achieves **comparable performance wi ### ⚙️ Environment Setup -We recommend using `python=3.10` for local deployment. +We recommend using **Python 3.10–3.12** for local deployment (newer versions like 3.14 may not have prebuilt wheels for key deps and will try to compile from source). Clone this repo and install locally. @@ -72,6 +72,12 @@ cd heartlib pip install -e . 
``` +Optional (recommended if you want **reference-audio conditioning** via MuQ-MuLan): + +``` +pip install -e ".[muq]" +``` + Download our pretrained checkpoints from huggingface or modelscope using the following command: ``` @@ -104,6 +110,12 @@ To generate music, run: python ./examples/run_music_generation.py --model_path=./ckpt --version="3B" ``` +To enable **reference-audio conditioning** (auto-download MuQ-MuLan from Hugging Face): + +``` +python ./examples/run_music_generation.py --model_path=./ckpt --version="3B" --load_muq_mulan --ref_audio /path/to/ref.wav +``` + By default this command will generate a piece of music conditioned on lyrics and tags provided in `./assets` folder. The output music will be saved at `./assets/output.mp3`. All parameters: diff --git a/examples/run_music_generation.py b/examples/run_music_generation.py index 710e382..1134ac0 100644 --- a/examples/run_music_generation.py +++ b/examples/run_music_generation.py @@ -11,7 +11,54 @@ def parse_args(): parser.add_argument("--version", type=str, default="3B") parser.add_argument("--lyrics", type=str, default="./assets/lyrics.txt") parser.add_argument("--tags", type=str, default="./assets/tags.txt") + parser.add_argument( + "--ref_audio", + type=str, + default=None, + help="Optional: path to reference audio for MuQ-MuLan conditioning.", + ) + parser.add_argument( + "--load_muq_mulan", + action="store_true", + help="Auto-download/load MuQ-MuLan from Hugging Face (requires `pip install muq`).", + ) + parser.add_argument( + "--muq_model_id", + type=str, + default="OpenMuQ/MuQ-MuLan-large", + help="Hugging Face model id for MuQ-MuLan.", + ) + parser.add_argument( + "--muq_cache_dir", + type=str, + default=None, + help="Optional: Hugging Face cache dir for MuQ-MuLan.", + ) + parser.add_argument( + "--muq_revision", + type=str, + default=None, + help="Optional: Hugging Face revision (branch/tag/commit) for MuQ-MuLan.", + ) + parser.add_argument( + "--muq_segment_sec", + type=float, + default=10.0, + help="Reference-audio segment length (seconds) fed to MuQ.", + ) + parser.add_argument( + "--muq_sample_rate", + type=int, + default=24000, + help="Sample rate expected by MuQ (usually 24 kHz).", + ) parser.add_argument("--save_path", type=str, default="./assets/output.mp3") + parser.add_argument( + "--codes_path", + type=str, + default=None, + help="Optional: save generated audio token frames (torch .pt) for analysis.", + ) parser.add_argument("--max_audio_length_ms", type=int, default=240_000) parser.add_argument("--topk", type=int, default=50) @@ -39,17 +86,27 @@ def parse_args(): device=device, dtype=dtype, version=args.version, + load_muq_mulan=args.load_muq_mulan, + muq_model_id=args.muq_model_id, + muq_cache_dir=args.muq_cache_dir, + muq_revision=args.muq_revision, ) with torch.no_grad(): pipe( { "lyrics": args.lyrics, "tags": args.tags, + "ref_audio": args.ref_audio, + "muq_segment_sec": args.muq_segment_sec, + "muq_sample_rate": args.muq_sample_rate, }, max_audio_length_ms=args.max_audio_length_ms, save_path=args.save_path, + codes_path=args.codes_path, topk=args.topk, temperature=args.temperature, cfg_scale=args.cfg_scale, ) print(f"Generated music saved to {args.save_path}") + if args.codes_path: + print(f"Saved audio token frames to {args.codes_path}") diff --git a/pyproject.toml b/pyproject.toml index 140fe04..70e15a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,9 @@ name = "heartlib" version = "0.1.0" description = "A Python Library." 
readme = "README.md" -requires-python = ">=3.9" +# Torch/torchaudio and many scientific wheels may not be available on bleeding-edge +# Python versions (e.g. 3.14), which forces fragile source builds. +requires-python = ">=3.9,<3.13" license = {text = "CC-BY-NC-4.0"} authors = [ {name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"} @@ -39,6 +41,10 @@ classifiers = [ "Operating System :: OS Independent" ] +[project.optional-dependencies] +# Optional: enables auto-download + inference for MuQ-MuLan reference-audio conditioning. +muq = ["muq"] + [tool.setuptools] package-dir = {"" = "src"} diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py index 3f9d3b3..ed421ef 100644 --- a/src/heartlib/pipelines/music_generation.py +++ b/src/heartlib/pipelines/music_generation.py @@ -6,6 +6,7 @@ import torch import torchaudio +import torch.nn.functional as F from tokenizers import Tokenizer from tqdm import tqdm from transformers import BitsAndBytesConfig @@ -75,6 +76,7 @@ def _sanitize_parameters(self, **kwargs): } postprocess_kwargs = { "save_path": kwargs.get("save_path", "output.mp3"), + "codes_path": kwargs.get("codes_path", None), } return preprocess_kwargs, forward_kwargs, postprocess_kwargs @@ -99,11 +101,157 @@ def preprocess(self, input_: Dict[str, Any], **preprocess_parameters: Any): if tags_ids[-1] != self.config.text_eos_id: tags_ids = tags_ids + [self.config.text_eos_id] + def _load_ref_audio(ref: Any) -> tuple[torch.Tensor, int]: + """ + Returns (mono_waveform, sample_rate) where mono_waveform is 1D [T]. + """ + if isinstance(ref, str): + wav, sr = torchaudio.load(ref) + elif isinstance(ref, torch.Tensor): + wav = ref + sr = int(input_.get("ref_audio_sr", 0) or 0) + if sr <= 0: + raise ValueError( + "ref_audio was provided as a Tensor but `ref_audio_sr` was missing/invalid." + ) + else: + raise TypeError( + f"ref_audio must be a file path or torch.Tensor, got {type(ref)}" + ) + + # Accept [T], [C,T], or [B,C,T] (take the first batch). + if wav.ndim == 3: + wav = wav[0] + if wav.ndim == 2: + wav = wav.mean(dim=0) + elif wav.ndim != 1: + raise ValueError(f"Unsupported ref_audio tensor shape: {tuple(wav.shape)}") + + wav = wav.to(dtype=torch.float32) + return wav, int(sr) + + def _prepare_muq_audio(wav: torch.Tensor, sr: int) -> torch.Tensor: + """ + Resample to MuQ sample rate (default 24k) and take/pad a ~10s segment. + Returns waveform shaped [1, T] on self._device. + """ + muq_sr = int(input_.get("muq_sample_rate", 24_000)) + seg_s = float(input_.get("muq_segment_sec", 10.0)) + seg_len = max(1, int(round(muq_sr * seg_s))) + + if sr != muq_sr: + wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=muq_sr) + + if wav.numel() >= seg_len: + start = (wav.numel() - seg_len) // 2 + wav = wav[start : start + seg_len] + else: + wav = F.pad(wav, (0, seg_len - wav.numel())) + + # Common MuQ-style encoders expect [B, T]. + return wav.unsqueeze(0).to(device=self._device) + + def _run_muq_mulan(audio_bt: torch.Tensor, sample_rate: int) -> torch.Tensor: + """ + Runs the provided MuQ-MuLan model and returns a 1D [muq_dim] embedding. + Tries a few common APIs / output layouts. + """ + if self.muq_mulan is None: + raise ValueError( + "ref_audio was provided but `muq_mulan` is None. " + "Pass a pretrained MuQ-MuLan model to HeartMuLaGenPipeline." 
+ ) + + model = self.muq_mulan + was_training = getattr(model, "training", False) + if hasattr(model, "eval"): + model.eval() + + with torch.inference_mode(): + out = None + # Common: model.encode_audio(audio, sample_rate=...) + if hasattr(model, "encode_audio") and callable(getattr(model, "encode_audio")): + try: + out = model.encode_audio(audio_bt, sample_rate=sample_rate) + except TypeError: + out = model.encode_audio(audio_bt) + # Fallback: callable model(audio, sample_rate=...) + if out is None and callable(model): + try: + out = model(audio_bt, sample_rate=sample_rate) + except TypeError: + out = model(audio_bt) + + if was_training and hasattr(model, "train"): + model.train() + + def _to_tensor(x: Any) -> Optional[torch.Tensor]: + if x is None: + return None + if isinstance(x, torch.Tensor): + return x + if isinstance(x, (tuple, list)) and x: + return _to_tensor(x[0]) + if isinstance(x, (dict, ModelOutput)): + for k in ( + "joint_embedding", + "joint_embeds", + "embedding", + "embeddings", + "audio_embedding", + "audio_embeds", + "audio_embed", + "audio_features", + "audio_feature", + ): + if k in x: + return _to_tensor(x[k]) + for attr in ( + "joint_embedding", + "embedding", + "embeddings", + "audio_embedding", + "audio_embeds", + "audio_features", + ): + if hasattr(x, attr): + return _to_tensor(getattr(x, attr)) + return None + + emb = _to_tensor(out) + if emb is None: + raise ValueError( + "Could not extract an embedding from `muq_mulan` output. " + "Expected a Tensor or a dict/ModelOutput with an embedding field." + ) + + # Accept [D], [1,D], or [B,D] (take first). + emb = emb.detach() + if emb.ndim == 2: + emb = emb[0] + elif emb.ndim != 1: + raise ValueError(f"Unsupported muq embedding shape: {tuple(emb.shape)}") + + if emb.numel() != self._muq_dim: + raise ValueError( + f"MuQ-MuLan embedding dim mismatch: expected {self._muq_dim}, got {emb.numel()}." + ) + + # Normalize is common for joint embeddings; safe and improves conditioning stability. + emb = emb / (emb.norm(p=2) + 1e-12) + return emb.to(device="cpu", dtype=self.dtype) + ref_audio = input_.get("ref_audio", None) if ref_audio is not None: - raise NotImplementedError("ref_audio is not supported yet.") - muq_embed = torch.zeros([self._muq_dim], dtype=self.dtype) - muq_idx = len(tags) + wav, sr = _load_ref_audio(ref_audio) + muq_sr = int(input_.get("muq_sample_rate", 24_000)) + audio_bt = _prepare_muq_audio(wav, sr) + muq_embed = _run_muq_mulan(audio_bt, sample_rate=muq_sr) + else: + muq_embed = torch.zeros([self._muq_dim], dtype=self.dtype) + + # The reserved slot is the blank "+1" token after tags_ids. + muq_idx = len(tags_ids) lyrics = input_["lyrics"] if os.path.isfile(lyrics): @@ -239,14 +387,24 @@ def _autocast_ctx(): frames = frame_buf[:frame_len].transpose(0, 1).contiguous() wav = self.audio_codec.detokenize(frames) - return ModelOutput(wav=wav) + # Include tokens in the output so postprocess can optionally persist them. + # This is opt-in (see postprocess `codes_path`) and does not change default behavior. 
+ return ModelOutput(wav=wav, codes=frames.detach().cpu()) def postprocess( self, model_outputs: ModelOutput, **postprocess_parameters: Any ) -> None: save_path: str = postprocess_parameters.get("save_path", "output.mp3") + codes_path: Optional[str] = postprocess_parameters.get("codes_path", None) wav = model_outputs["wav"] torchaudio.save(save_path, wav, 48000) + if codes_path: + codes = model_outputs.get("codes", None) + if codes is None: + raise ValueError( + "codes_path was provided but no `codes` were found in model outputs." + ) + torch.save(codes, codes_path) @classmethod def from_pretrained( @@ -256,6 +414,11 @@ def from_pretrained( dtype: torch.dtype, version: str, bnb_config: Optional[BitsAndBytesConfig] = None, + *, + load_muq_mulan: bool = False, + muq_model_id: Optional[str] = None, + muq_cache_dir: Optional[str] = None, + muq_revision: Optional[str] = None, ): if os.path.exists( heartcodec_path := os.path.join(pretrained_path, "HeartCodec-oss") @@ -295,4 +458,38 @@ def from_pretrained( f"Expected to find gen_config.json for HeartMuLa at {gen_config_path} but not found. Please check your folder {pretrained_path}." ) - return cls(heartmula, heartcodec, None, tokenizer, gen_config, device, dtype) + # Optional: load MuQ-MuLan from Hugging Face (auto-download + cache). + # Enable via argument or env HEARTLIB_LOAD_MUQ_MULAN=1. + if not load_muq_mulan: + load_muq_mulan = os.getenv("HEARTLIB_LOAD_MUQ_MULAN", "0") == "1" + + muq_mulan = None + if load_muq_mulan: + model_id = ( + muq_model_id + or os.getenv("HEARTLIB_MUQ_MULAN_ID", "").strip() + or "OpenMuQ/MuQ-MuLan-large" + ) + try: + # MuQ's own library wraps Hugging Face download via .from_pretrained(). + # Install: pip install muq + from muq import MuQMuLan # type: ignore + except Exception as e: # pragma: no cover + raise ImportError( + "MuQ-MuLan auto-download requested, but the `muq` package is not installed. " + "Install it with: pip install muq" + ) from e + + kwargs: Dict[str, Any] = {} + if muq_cache_dir is not None: + kwargs["cache_dir"] = muq_cache_dir + if muq_revision is not None: + kwargs["revision"] = muq_revision + + muq_mulan = MuQMuLan.from_pretrained(model_id, **kwargs) + if hasattr(muq_mulan, "to"): + muq_mulan = muq_mulan.to(device) + if hasattr(muq_mulan, "eval"): + muq_mulan.eval() + + return cls(heartmula, heartcodec, muq_mulan, tokenizer, gen_config, device, dtype)
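
As a quick reference for reviewers, here is a minimal programmatic sketch of the reference-audio path this last patch enables, mirroring the new CLI flags. It is illustrative rather than part of the patch: the checkpoint path, the reference clip, and the sampling values are placeholders, `load_muq_mulan=True` requires the optional `muq` extra (`pip install -e ".[muq]"`), and an MPS or CUDA device is assumed to be available.

```
import torch

from heartlib import HeartMuLaGenPipeline

# Placeholder checkpoint/asset paths; adjust to your local setup.
pipe = HeartMuLaGenPipeline.from_pretrained(
    "./ckpt",
    device=torch.device("mps"),
    dtype=torch.float16,
    version="3B",
    load_muq_mulan=True,                     # auto-download MuQ-MuLan (needs the `muq` extra)
    muq_model_id="OpenMuQ/MuQ-MuLan-large",
)

with torch.no_grad():
    pipe(
        {
            "lyrics": "./assets/lyrics.txt",
            "tags": "./assets/tags.txt",
            "ref_audio": "./assets/reference.wav",  # clip used for MuQ-MuLan conditioning
            "muq_segment_sec": 10.0,                # center segment length fed to MuQ
            "muq_sample_rate": 24_000,
        },
        max_audio_length_ms=240_000,
        save_path="./assets/output.mp3",
        codes_path="./assets/output_codes.pt",      # optional: persist audio token frames
        topk=50,
        temperature=1.0,
        cfg_scale=1.5,
    )
```

The tensor written to `codes_path` is the same frame layout that `forward` passes to `audio_codec.detokenize`, so it can be reloaded later with `torch.load` and inspected, or re-decoded after moving it back to the codec's device, without rerunning generation.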
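
Stepping back to the fused Metal kernels added earlier in this series, the sketch below is one way to sanity-check them numerically against plain PyTorch references. It is illustrative only and makes a few assumptions: the JIT extension builds on the host, `ops.rmsnorm` normalizes a contiguous 2-D input over its last dimension, and the fp16 tolerances are rough choices.

```
import torch

from heartlib.accelerators.metal.rmsnorm import metal_rmsnorm_available, rmsnorm_fp16
from heartlib.accelerators.metal.rope import metal_rope_available, rope_fp16


def rmsnorm_ref(x, weight, eps=1e-6):
    # y = x * rsqrt(mean(x^2) + eps) * weight, accumulated in fp32 like the kernel.
    xf = x.float()
    y = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    if weight is not None:
        y = y * weight.float()
    return y.to(x.dtype)


def rope_ref(x, cos, sin, rot_dim):
    # x: (B, H, T, D); cos/sin: (T, rot_dim // 2); channels >= rot_dim pass through.
    half = rot_dim // 2
    x1, x2, rest = x[..., :half], x[..., half:rot_dim], x[..., rot_dim:]
    c = cos.to(x.dtype).view(1, 1, cos.shape[0], half)
    s = sin.to(x.dtype).view(1, 1, sin.shape[0], half)
    return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c, rest], dim=-1)


if metal_rmsnorm_available() and metal_rope_available():
    dev = torch.device("mps")

    # RMSNorm: rows x d_model, matching the kernel's row-wise reduction.
    x2d = torch.randn(8, 512, device=dev, dtype=torch.float16)
    w = torch.randn(512, device=dev, dtype=torch.float16)
    torch.testing.assert_close(
        rmsnorm_fp16(x=x2d, weight=w, eps=1e-6),
        rmsnorm_ref(x2d, w),
        rtol=1e-2, atol=1e-2,
    )

    # RoPE: (B, H, T, D) with cos/sin tables of shape (T, rot_dim / 2).
    x4d = torch.randn(2, 4, 16, 64, device=dev, dtype=torch.float16)
    cos = torch.randn(16, 16, device=dev, dtype=torch.float16)
    sin = torch.randn(16, 16, device=dev, dtype=torch.float16)
    torch.testing.assert_close(
        rope_fp16(x=x4d, cos=cos, sin=sin, rot_dim=32),
        rope_ref(x4d, cos, sin, 32),
        rtol=1e-2, atol=1e-2,
    )
```

Note that `metal_rmsnorm_available()` and `metal_rope_available()` only check that MPS is usable; actually compiling the extension also needs the Xcode toolchain that `metal_build_tools_available()` probes for.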