@@ -122,7 +122,7 @@ def INPUT_TYPES(cls):
122122 },
123123 }
124124
125- CATEGORY = "MW/MW-DiffRhythm"
125+ CATEGORY = "🎤 MW/MW-DiffRhythm"
126126 RETURN_TYPES = ("STRING" ,)
127127 RETURN_NAMES = ("prompt" ,)
128128 FUNCTION = "promptgen"
@@ -166,7 +166,7 @@ def INPUT_TYPES(cls):
166166 },
167167 }
168168
169- CATEGORY = "MW/MW-DiffRhythm"
169+ CATEGORY = "🎤 MW/MW-DiffRhythm"
170170 RETURN_TYPES = ("AUDIO" ,)
171171 RETURN_NAMES = ("audio" ,)
172172 FUNCTION = "diffrhythmgen"
@@ -295,52 +295,6 @@ def get_text_style_prompt(self, model, text_prompt):
295295
296296 return text_emb
297297
298-
299- # @torch.no_grad()
300- # def get_style_prompt(self, model, audio=None, prompt=None):
301- # mulan = model
302-
303- # if prompt is not None:
304- # return mulan(texts=prompt).half()
305-
306- # if audio is None:
307- # raise ValueError("Audio data or style prompt must be provided")
308-
309- # waveform = audio["waveform"]
310- # sample_rate = audio["sample_rate"]
311-
312- # # Ensure waveform has correct shape
313- # if len(waveform.shape) == 3: # [1, channels, samples]
314- # waveform = waveform.squeeze(0)
315- # if waveform.shape[0] > 1: # If stereo, convert to mono
316- # waveform = waveform.mean(0, keepdim=True)
317-
318- # # Calculate audio length (seconds)
319- # audio_len = waveform.shape[-1] / sample_rate
320-
321- # if audio_len < 10:
322- # raise ValueError(f"Audio too short ({audio_len:.2f}s), minimum 10 seconds required.")
323-
324- # # Extract middle 10-second segment
325- # mid_time = int((audio_len // 2) * sample_rate)
326- # start_sample = mid_time - int(5 * sample_rate)
327- # end_sample = start_sample + int(10 * sample_rate)
328- # wav_segment = waveform[..., start_sample:end_sample]
329-
330- # # Resample to 24kHz
331- # if sample_rate != 24000:
332- # wav_segment = torchaudio.transforms.Resample(sample_rate, 24000)(wav_segment)
333-
334- # # Ensure correct shape and device
335- # wav = wav_segment.to(model.device)
336- # if len(wav.shape) == 1:
337- # wav = wav.unsqueeze(0)
338-
339- # with torch.no_grad():
340- # audio_emb = mulan(wavs=wav) # [1, 512]
341-
342- # audio_emb = audio_emb.half()
343- # return audio_emb
344298
345299 def prepare_model (self , model , device , unload_model = False ):
346300 # prepare tokenizer
0 commit comments