1+ # Copyright (c) 2025 ASLP-LAB
2+ # 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
3+ # 2025 Guobin Ma (guobin.ma@gmail.com)
4+ #
5+ # Licensed under the Stability AI License (the "License");
6+ # you may not use this file except in compliance with the License.
7+ # You may obtain a copy of the License at
8+ #
9+ # https://huggingface.co/stabilityai/stable-audio-open-1.0/blob/main/LICENSE.md
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+
17+ import torchaudio
18+ import librosa
19+ from mutagen .mp3 import MP3
20+ import torch
21+ from einops import rearrange
22+
23+ from diffrhythm_utils import (
24+ decode_audio ,
25+ get_lrc_token ,
26+ get_negative_style_prompt ,
27+ get_reference_latent ,
28+ prepare_model ,
29+ )
30+
31+
def inference(
    cfm_model,
    vae_model,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    start_time,
    chunked=False,
):
    """Sample a latent song from the CFM model and decode it to int16 audio.

    Args:
        cfm_model: conditional flow-matching model exposing ``.sample(...)``.
        vae_model: VAE decoder passed through to ``decode_audio``.
        cond: reference latent conditioning tensor.
        text: tokenized lyrics prompt.
        duration: number of latent frames to generate.
        style_prompt: style embedding tensor.
        negative_style_prompt: negative style embedding for CFG.
        start_time: sampling start time(s) for the CFM sampler.
        chunked: decode the latent in chunks to reduce peak memory.

    Returns:
        torch.Tensor: int16 CPU tensor of shape [d, b*n] — peak-normalized
        audio with all batch items concatenated along time.
    """
    # Sampling needs no autograd state; inference_mode also skips
    # version-counter bookkeeping.
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=32,
            cfg_strength=4.0,
            start_time=start_time,
        )

    generated = generated.to(torch.float32)
    latent = generated.transpose(1, 2)  # [b d t]

    output = decode_audio(latent, vae_model, chunked=chunked)

    # Rearrange audio batch to a single sequence
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip, convert to int16.
    # Guard the peak against zero so a silent decode yields zeros
    # instead of NaN from a 0/0 division.
    peak = torch.max(torch.abs(output)).clamp(min=1e-8)
    output = (
        output.to(torch.float32)
        .div(peak)
        .clamp(-1, 1)
        .mul(32767)
        .to(torch.int16)
        .cpu()
    )

    return output
74+
75+
class MultiLinePrompt:
    """ComfyUI helper node: pass a multi-line text prompt through, stripped."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required multi-line string input."""
        return {
            "required": {
                "multi_line_prompt": (
                    "STRING",
                    {"multiline": True, "default": ""},
                ),
            },
        }

    CATEGORY = "MW-DiffRhythm"
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prompt",)
    FUNCTION = "promptgen"

    def promptgen(self, multi_line_prompt: str):
        # ComfyUI expects outputs as a tuple; strip surrounding whitespace.
        cleaned = multi_line_prompt.strip()
        return (cleaned,)
95+
96+
class DiffRhythmRun:
    """ComfyUI node that generates a song with the DiffRhythm CFM/VAE models."""

    # Pick the best available accelerator once, at class-definition time.
    # torch.backends.mps.is_available() is the documented MPS probe;
    # torch.mps.is_available() is not present on all torch builds and could
    # raise AttributeError at import time.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "style_prompt": ("STRING", {
                    "multiline": True,
                    "default": ""}),
            },
            "optional": {
                "lyrics_prompt": ("STRING",),
                "style_audio": ("AUDIO",),
                "chunked": ("BOOLEAN", {"default": False, "tooltip": "Whether to use chunked decoding."}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
            },
        }

    CATEGORY = "MW-DiffRhythm"
    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "diffrhythmgen"

    def diffrhythmgen(
            self,
            style_prompt: str,
            lyrics_prompt: str = "",
            style_audio=None,
            chunked: bool = False,
            seed: int = 0):
        """Generate a song and return it as a ComfyUI AUDIO dict.

        Args:
            style_prompt: textual style description (used when no reference
                audio is supplied).
            lyrics_prompt: LRC-style lyrics text; may be empty.
            style_audio: optional ComfyUI AUDIO dict ({"waveform", "sample_rate"})
                used as the style reference instead of the text prompt.
            chunked: decode the latent in chunks to reduce peak memory.
            seed: accepted for ComfyUI caching; not used directly here.

        Returns:
            tuple: one AUDIO dict with a [1, d, n] waveform at 44.1 kHz.
        """
        # Only the 95 s model variant (2048 latent frames) is supported here;
        # the 285 s variant (6144 frames) is not currently available.
        max_frames = 2048
        cfm, tokenizer, muq, vae = prepare_model(self.device)

        lrc_prompt, start_time = get_lrc_token(lyrics_prompt, tokenizer, self.device)

        # Prefer a reference audio clip for the style embedding; otherwise
        # fall back to the textual style prompt.
        if style_audio is not None:
            prompt = self.get_style_prompt(muq, style_audio)
        else:
            prompt = self.get_style_prompt(muq, prompt=style_prompt)

        negative_style_prompt = get_negative_style_prompt(self.device)
        latent_prompt = get_reference_latent(self.device, max_frames)

        generated_song = inference(
            cfm_model=cfm,
            vae_model=vae,
            cond=latent_prompt,
            text=lrc_prompt,
            duration=max_frames,
            style_prompt=prompt,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            chunked=chunked,
        )

        # ComfyUI AUDIO expects a [batch, channels, samples] waveform.
        audio_tensor = generated_song.unsqueeze(0)
        return ({"waveform": audio_tensor, "sample_rate": 44100},)

    @torch.no_grad()
    def get_style_prompt(self, model, audio=None, prompt=None):
        """Compute the half-precision style embedding from text or audio.

        Args:
            model: MuQ-MuLan style encoder; must expose ``.device`` and be
                callable with ``texts=`` or ``wavs=``.
            audio: optional ComfyUI AUDIO dict with "waveform" and
                "sample_rate"; the middle 10 s are used.
            prompt: optional text style prompt; takes precedence over audio.

        Returns:
            torch.Tensor: float16 embedding (presumably [1, 512] — per the
            original comment; confirm against the encoder).

        Raises:
            ValueError: if neither input is given, or the audio is < 10 s.
        """
        mulan = model

        if prompt is not None:
            return mulan(texts=prompt).half()

        if audio is None:
            raise ValueError("Audio data or style prompt must be provided")

        waveform = audio["waveform"]
        sample_rate = audio["sample_rate"]

        # Normalize shape: drop the batch dim, then downmix stereo to mono.
        if len(waveform.shape) == 3:  # [1, channels, samples]
            waveform = waveform.squeeze(0)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)

        # Audio length in seconds.
        audio_len = waveform.shape[-1] / sample_rate

        if audio_len < 10:
            raise ValueError(f"The audio is too short ({audio_len:.2f} s), it takes at least 10 seconds.")

        # Extract the middle 10-second segment.
        mid_time = int((audio_len // 2) * sample_rate)
        start_sample = mid_time - int(5 * sample_rate)
        end_sample = start_sample + int(10 * sample_rate)
        wav_segment = waveform[..., start_sample:end_sample]

        # The style encoder expects 24 kHz input.
        if sample_rate != 24000:
            wav_segment = torchaudio.transforms.Resample(sample_rate, 24000)(wav_segment)

        # Move to the encoder's device and ensure a leading channel dim.
        wav = wav_segment.to(model.device)
        if len(wav.shape) == 1:
            wav = wav.unsqueeze(0)

        # Gradients are already disabled by the @torch.no_grad() decorator.
        audio_emb = mulan(wavs=wav)  # [1, 512] per the original annotation

        return audio_emb.half()
216+
217+
# Registration table read by ComfyUI: node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "DiffRhythmRun": DiffRhythmRun,
    "MultiLinePrompt": MultiLinePrompt,
}

# Human-readable titles shown for each node in the ComfyUI editor.
NODE_DISPLAY_NAME_MAPPINGS = {
    "DiffRhythmRun": "DiffRhythm Run",
    "MultiLinePrompt": "Multi Line Prompt",
}
0 commit comments