Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ messages_asr = [
# You can provide context or instructions as text
{"role": "user", "message_type": "text", "content": "Please transcribe the following audio:"},
# Provide the audio file path
# Note: the content may also be a raw waveform as a 2-D torch.Tensor sampled at 16 kHz (first dimension is 1, i.e. shape = [1, sample_length])
{"role": "user", "message_type": "audio", "content": "test_audios/asr_example.wav"}
]

Expand Down
20 changes: 12 additions & 8 deletions kimia_infer/api/prompt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def _tokenize_text(self, text):
token_ids = self.text_tokenizer.encode(text, bos=False, eos=False)
return token_ids

def _tokenize_audio(self, wav_path):
wav_tokens = self.audio_tokenizer.tokenize(audio_path=wav_path)
def _tokenize_audio(self, wav_path_or_waveform):
    """Tokenize an audio input into a flat list of discrete token ids.

    Args:
        wav_path_or_waveform: either a path (str) to an audio file, or a
            raw waveform as a 2-D torch.Tensor of shape [1, sample_length].
            NOTE(review): a raw waveform is assumed to already be sampled
            at 16 kHz -- this cannot be verified here, so callers must
            resample beforehand (see the README note for the contract).

    Returns:
        list[int]: audio token ids shifted by ``kimia_token_offset`` so
        they live in the same id space as the text tokenizer's vocabulary.

    Raises:
        ValueError: if a waveform is given but is not a 2-D tensor.
    """
    if isinstance(wav_path_or_waveform, str):
        wav_tokens = self.audio_tokenizer.tokenize(audio_path=wav_path_or_waveform)
    else:
        # Mirror extract_whisper_feat's 2-D contract and fail fast: a
        # wrongly shaped waveform would otherwise tokenize "successfully"
        # but yield silently wrong results downstream.
        if not isinstance(wav_path_or_waveform, torch.Tensor) or wav_path_or_waveform.dim() != 2:
            raise ValueError(
                f"Invalid waveform input: {type(wav_path_or_waveform)}, "
                "expected a file path (str) or a 2-D torch.Tensor of shape [1, sample_length]"
            )
        wav_tokens = self.audio_tokenizer.tokenize(speech=wav_path_or_waveform)
    # Shift codec ids past the text vocabulary so text and audio tokens
    # share a single id space.
    wav_tokens = wav_tokens + self.kimia_token_offset
    wav_tokens_list = wav_tokens.squeeze(0).cpu().numpy().tolist()
    return wav_tokens_list
Expand All @@ -63,8 +66,9 @@ def extract_whisper_feat(self, wav: torch.Tensor | str):
wav_tensor = torch.tensor(wav).unsqueeze(0)[:, :]
elif isinstance(wav, torch.Tensor):
wav_tensor = wav
assert len(wav_tensor.shape) == 2, "The wav tensor must be a 2D tensor"
else:
raise ValueError(f"Invalid wav type: {type(wav)}")
raise ValueError(f"Invalid wav type: {type(wav)}, wav must be a string or a 2-D torch.Tensor")
assert self.whisper_model is not None
wav_tensor = wav_tensor.to(torch.cuda.current_device())
continous_feature = self.whisper_model.tokenize_waveform(wav_tensor)
Expand Down Expand Up @@ -116,11 +120,11 @@ def tokenize_message(
kimia_content_msg.audio_append(self.extra_tokens.kimia_text_blank, audio_token_loss_mask=False)

elif message["message_type"] == "audio":
audio_path_or_waveform = message["content"]
if "audio_tokens" in message:
speech_tokens = message["audio_tokens"]
else:
audio_path = message["content"]
speech_tokens = self._tokenize_audio(audio_path)
speech_tokens = self._tokenize_audio(audio_path_or_waveform)

kimia_content_msg.audio_append(self.extra_tokens.media_begin)
kimia_content_msg.audio_extend(speech_tokens, is_continuous=True, audio_token_loss_mask=has_loss)
Expand All @@ -139,11 +143,11 @@ def tokenize_message(
kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank)

if extract_whisper_feature:
whisper_feature = self.extract_whisper_feat(audio_path)
whisper_feature = self.extract_whisper_feat(audio_path_or_waveform)
kimia_content_msg.continuous_feature.append(whisper_feature)
elif message["message_type"] == "audio-text":
audio_path, text = message["content"]
speech_tokens = self._tokenize_audio(audio_path)
audio_path_or_waveform, text = message["content"]
speech_tokens = self._tokenize_audio(audio_path_or_waveform)
text_tokens = self._tokenize_text(text)

kimia_content_msg.audio_extend([self.extra_tokens.kimia_text_blank] * self.kimia_text_audiodelaytokens)
Expand Down