diff --git a/evaluate_efficiency_salmonn.py b/evaluate_efficiency_salmonn.py
index 12422a8..54065d5 100644
--- a/evaluate_efficiency_salmonn.py
+++ b/evaluate_efficiency_salmonn.py
@@ -106,9 +106,25 @@ def get_gpu_memory_usage():
     return gpu_memory
 
 
-def model_inference(cfg, samples, test_prompt, salmonn):
+def audio_preprocess(cfg, wav_processor, audio, sr=16000):
+    spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]
+    audio = audio.unsqueeze(0)
+    sample_batch = {
+        "spectrogram": spectrogram,
+        "raw_wav": audio,
+        "text": ["test"],
+        "task": ["asr"],
+        "Q": [""],
+        "id": [0],
+    }
+    sample_batch = prepare_sample(sample_batch, cuda_enabled=torch.cuda.is_available(), device=cfg.config.run.device)
+    return sample_batch
+
+
+def model_inference(cfg, audio, test_prompt, salmonn, wav_processor):
     # TTFT
     start_time = time.time()
+    samples = audio_preprocess(cfg, wav_processor, audio)
     llm = salmonn.llama_model
 
     batch_size = samples["spectrogram"].shape[0]
@@ -181,9 +197,13 @@ def main(args):
     # Load dataset
     with open("audiolm-trainer/prompts/test_prompt.json", "r") as f:
         test_prompt = json.load(f)
-    dataloader = MockDataset.make_mock_dataloader(cfg, sr=16000, audio_length=10)
-    sample_batch = next(iter(dataloader))
-    sample_batch = prepare_sample(sample_batch, cuda_enabled=torch.cuda.is_available())
+    # dataloader = MockDataset.make_mock_dataloader(cfg, sr=16000, audio_length=10)
+    # sample_batch = next(iter(dataloader))
+    # sample_batch = prepare_sample(sample_batch, cuda_enabled=torch.cuda.is_available())
+    sr = 16000
+    audio_length = 10
+    wav_processor = WhisperFeatureExtractor.from_pretrained(cfg.config.model['whisper_path'])
+    random_sample = torch.randn(sr * audio_length)
 
     # Measure memory and latency
     memory_usages = []
@@ -196,9 +216,10 @@ def main(args):
         with torch.no_grad():
             inference_time, ttft, tpot = model_inference(
                 cfg,
-                sample_batch,
+                random_sample,
                 test_prompt,
                 salmonn_preprocessor,
+                wav_processor,
             )
         torch.cuda.synchronize()
         after_memory_allocated = torch.cuda.max_memory_allocated()
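
For reference, a minimal standalone sketch of the new preprocessing path (not part of the patch): it assumes a public Whisper checkpoint "openai/whisper-large-v2" in place of cfg.config.model['whisper_path'] and skips the repo-specific prepare_sample step, so it only illustrates the tensor shapes that audio_preprocess hands to model_inference.

    import torch
    from transformers import WhisperFeatureExtractor

    # Assumed public checkpoint; the script itself reads cfg.config.model['whisper_path'].
    wav_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")

    sr, audio_length = 16000, 10
    audio = torch.randn(sr * audio_length)  # same synthetic 10 s input as the patched main()

    # WhisperFeatureExtractor pads/truncates to 30 s of audio, so input_features
    # comes back as (1, n_mels, 3000) regardless of the 10 s input length.
    spectrogram = wav_processor(audio.numpy(), sampling_rate=sr, return_tensors="pt")["input_features"]
    raw_wav = audio.unsqueeze(0)  # (1, sr * audio_length), matching the "raw_wav" batch entry

    print(spectrogram.shape, raw_wav.shape)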