From 8ce115fa8346efc2ffbbf3aad0002e434e1447f1 Mon Sep 17 00:00:00 2001
From: Jaeyeon Kim <0310kjy@gmail.com>
Date: Tue, 11 Feb 2025 13:46:47 +0900
Subject: [PATCH] Update evaluate_efficiency_salmonn.py

Replace the mock dataloader with a randomly generated waveform and move
audio preprocessing into model_inference() via a new audio_preprocess()
helper, so that Whisper feature extraction is counted in the measured
TTFT.
---
 evaluate_efficiency_salmonn.py | 32 +++++++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/evaluate_efficiency_salmonn.py b/evaluate_efficiency_salmonn.py
index 12422a8..54065d5 100644
--- a/evaluate_efficiency_salmonn.py
+++ b/evaluate_efficiency_salmonn.py
@@ -106,9 +106,27 @@ def get_gpu_memory_usage():
     return gpu_memory
 
 
-def model_inference(cfg, samples, test_prompt, salmonn):
+def audio_preprocess(cfg, wav_processor, audio, sr=16000):
+    # Build a single-sample batch from a raw waveform: extract the Whisper
+    # spectrogram, add a batch dimension, and move the tensors to the device.
+    spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]
+    audio = audio.unsqueeze(0)
+    sample_batch = {
+        "spectrogram": spectrogram,
+        "raw_wav": audio,
+        "text": ["test"],
+        "task": ["asr"],
+        "Q": [""],
+        "id": [0],
+    }
+    sample_batch = prepare_sample(sample_batch, cuda_enabled=torch.cuda.is_available(), device=cfg.config.run.device)
+    return sample_batch
+
+
+def model_inference(cfg, audio, test_prompt, salmonn, wav_processor):
     # TTFT
     start_time = time.time()
+    samples = audio_preprocess(cfg, wav_processor, audio)
 
     llm = salmonn.llama_model
     batch_size = samples["spectrogram"].shape[0]
@@ -181,9 +199,13 @@ def main(args):
     # Load dataset
     with open("audiolm-trainer/prompts/test_prompt.json", "r") as f:
         test_prompt = json.load(f)
-    dataloader = MockDataset.make_mock_dataloader(cfg, sr=16000, audio_length=10)
-    sample_batch = next(iter(dataloader))
-    sample_batch = prepare_sample(sample_batch, cuda_enabled=torch.cuda.is_available())
+    # dataloader = MockDataset.make_mock_dataloader(cfg, sr=16000, audio_length=10)
+    # sample_batch = next(iter(dataloader))
+    # sample_batch = prepare_sample(sample_batch, cuda_enabled=torch.cuda.is_available())
+    sr = 16000
+    audio_length = 10  # seconds of synthetic audio
+    wav_processor = WhisperFeatureExtractor.from_pretrained(cfg.config.model["whisper_path"])
+    random_sample = torch.randn(sr * audio_length)
 
     # Measure memory and latency
     memory_usages = []
@@ -196,9 +218,10 @@ def main(args):
     with torch.no_grad():
         inference_time, ttft, tpot = model_inference(
             cfg,
-            sample_batch,
+            random_sample,
             test_prompt,
             salmonn_preprocessor,
+            wav_processor,
         )
         torch.cuda.synchronize()
         after_memory_allocated = torch.cuda.max_memory_allocated()
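
Note for reviewers (not part of the patch): below is a minimal, standalone
sketch of the new preprocessing path so it can be tried outside the repo.
The checkpoint name "openai/whisper-large-v2" and the bare timing loop are
assumptions for illustration; the script itself reads the checkpoint path
from cfg.config.model["whisper_path"] and times the full SALMONN forward
pass.

import time

import torch
from transformers import WhisperFeatureExtractor

SR = 16000
AUDIO_LENGTH_S = 10

# Assumed checkpoint for illustration; the real script loads whatever cfg points at.
wav_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")
audio = torch.randn(SR * AUDIO_LENGTH_S)  # synthetic waveform, no dataset needed

start = time.time()
# Feature extraction runs inside the timed region, mirroring the patch, which
# moved preprocessing into model_inference() so it counts toward TTFT.
features = wav_processor(audio.numpy(), sampling_rate=SR, return_tensors="pt")["input_features"]
sample_batch = {
    "spectrogram": features,
    "raw_wav": audio.unsqueeze(0),  # add a batch dimension
    "text": ["test"],
    "task": ["asr"],
    "Q": [""],
    "id": [0],
}
elapsed = time.time() - start
print(f"preprocess time now inside TTFT: {elapsed:.4f}s, "
      f"spectrogram shape: {tuple(features.shape)}")

One consequence of this ordering worth noting: the reported TTFT now includes
feature extraction and host-to-device transfer, so it reflects end-to-end
latency from raw audio rather than LLM prefill alone.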