From 43b40202e3d78f567d7131d43486272622687570 Mon Sep 17 00:00:00 2001 From: YichongLeng Date: Sat, 21 Jun 2025 23:30:16 +0800 Subject: [PATCH] support input waveform --- README.md | 1 + kimia_infer/api/prompt_manager.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 3ace831..694f655 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ messages_asr = [ # You can provide context or instructions as text {"role": "user", "message_type": "text", "content": "Please transcribe the following audio:"}, # Provide the audio file path + # Note that the content can also be a waveform as a 2-D torch.Tensor with sr=16000 (the first dimension has size 1, i.e., shape = [1, sample_length]) {"role": "user", "message_type": "audio", "content": "test_audios/asr_example.wav"} ] diff --git a/kimia_infer/api/prompt_manager.py b/kimia_infer/api/prompt_manager.py index 80fd78e..80aa7a9 100644 --- a/kimia_infer/api/prompt_manager.py +++ b/kimia_infer/api/prompt_manager.py @@ -50,8 +50,11 @@ def _tokenize_text(self, text): token_ids = self.text_tokenizer.encode(text, bos=False, eos=False) return token_ids - def _tokenize_audio(self, wav_path): - wav_tokens = self.audio_tokenizer.tokenize(audio_path=wav_path) + def _tokenize_audio(self, wav_path_or_waveform): + if isinstance(wav_path_or_waveform, str): + wav_tokens = self.audio_tokenizer.tokenize(audio_path=wav_path_or_waveform) + else: + wav_tokens = self.audio_tokenizer.tokenize(speech=wav_path_or_waveform) wav_tokens = wav_tokens + self.kimia_token_offset wav_tokens_list = wav_tokens.squeeze(0).cpu().numpy().tolist() return wav_tokens_list @@ -63,8 +66,9 @@ def extract_whisper_feat(self, wav: torch.Tensor | str): wav_tensor = torch.tensor(wav).unsqueeze(0)[:, :] elif isinstance(wav, torch.Tensor): wav_tensor = wav + assert len(wav_tensor.shape) == 2, "The wav tensor must be a 2D tensor" else: - raise ValueError(f"Invalid wav type: {type(wav)}") + raise 
ValueError(f"Invalid wav type: {type(wav)}, wav must be a string or a 2-D torch.Tensor") assert self.whisper_model is not None wav_tensor = wav_tensor.to(torch.cuda.current_device()) continous_feature = self.whisper_model.tokenize_waveform(wav_tensor) @@ -116,11 +120,11 @@ def tokenize_message( kimia_content_msg.audio_append(self.extra_tokens.kimia_text_blank, audio_token_loss_mask=False) elif message["message_type"] == "audio": + audio_path_or_waveform = message["content"] if "audio_tokens" in message: speech_tokens = message["audio_tokens"] else: - audio_path = message["content"] - speech_tokens = self._tokenize_audio(audio_path) + speech_tokens = self._tokenize_audio(audio_path_or_waveform) kimia_content_msg.audio_append(self.extra_tokens.media_begin) kimia_content_msg.audio_extend(speech_tokens, is_continuous=True, audio_token_loss_mask=has_loss) @@ -139,11 +143,11 @@ def tokenize_message( kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) if extract_whisper_feature: - whisper_feature = self.extract_whisper_feat(audio_path) + whisper_feature = self.extract_whisper_feat(audio_path_or_waveform) kimia_content_msg.continuous_feature.append(whisper_feature) elif message["message_type"] == "audio-text": - audio_path, text = message["content"] - speech_tokens = self._tokenize_audio(audio_path) + audio_path_or_waveform, text = message["content"] + speech_tokens = self._tokenize_audio(audio_path_or_waveform) text_tokens = self._tokenize_text(text) kimia_content_msg.audio_extend([self.extra_tokens.kimia_text_blank] * self.kimia_text_audiodelaytokens)