Hey, I have been using the codebase for quite some time and I noticed extensive memory consumption by the infer.py script. In my case it consumes around 120 GB of RAM, which is a big problem: my machine has 2 TB of RAM, but I need to run several of these jobs in parallel.
I ran a simple per-line RAM profiler and this is the output:
```
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
207 411.480 MiB 411.480 MiB 1 @profile
208 def main():
209 411.594 MiB 0.113 MiB 1 args = parse_arguments()
210
211 # For reproducibility
212 411.602 MiB 0.008 MiB 1 torch.manual_seed(args.seed)
213 411.602 MiB 0.000 MiB 1 torch.cuda.manual_seed(args.seed)
214 411.602 MiB 0.000 MiB 1 torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
215 411.602 MiB 0.000 MiB 1 np.random.seed(args.seed) # Numpy module.
216 411.602 MiB 0.000 MiB 1 random.seed(args.seed) # Python random module.
217 411.602 MiB 0.000 MiB 1 torch.manual_seed(args.seed)
218 411.602 MiB 0.000 MiB 1 torch.backends.cudnn.benchmark = False
219 411.602 MiB 0.000 MiB 1 torch.backends.cudnn.deterministic = True
220 411.602 MiB 0.000 MiB 1 os.environ['PYTHONHASHSEED'] = str(args.seed)
221
222 411.602 MiB 0.000 MiB 1 logging.info(args)
223
224 483.664 MiB 483.664 MiB 1 infer_loader = get_infer_dataloader(args)
225
226 483.664 MiB 0.000 MiB 1 if args.gpu >= 1:
227 gpuid = use_single_gpu(args.gpu)
228 logging.info('GPU device {} is used'.format(gpuid))
229 args.device = torch.device("cuda")
230 else:
231 483.664 MiB 0.000 MiB 1 gpuid = -1
232 483.664 MiB 0.000 MiB 1 args.device = torch.device("cpu")
233
234 483.664 MiB 0.000 MiB 1 assert args.estimate_spk_qty_thr != -1 or \
235 args.estimate_spk_qty != -1, \
236 ("Either 'estimate_spk_qty_thr' or 'estimate_spk_qty' "
237 "arguments have to be defined.")
238 483.664 MiB 0.000 MiB 1 if args.estimate_spk_qty != -1:
239 483.664 MiB 0.000 MiB 3 out_dir = join(args.rttms_dir, f"spkqty{args.estimate_spk_qty}_\
240 483.664 MiB 0.000 MiB 2 thr{args.threshold}_median{args.median_window_length}")
241 elif args.estimate_spk_qty_thr != -1:
242 out_dir = join(args.rttms_dir, f"spkqtythr{args.estimate_spk_qty_thr}_\
243 thr{args.threshold}_median{args.median_window_length}")
244
245 510.508 MiB 26.844 MiB 1 model = get_model(args)
246
247 812.148 MiB 301.641 MiB 2 model = average_checkpoints(
248 510.508 MiB 0.000 MiB 1 args.device, model, args.models_path, args.epochs)
249 812.148 MiB 0.000 MiB 1 model.eval()
250
251 812.148 MiB 0.000 MiB 2 out_dir = join(
252 812.148 MiB 0.000 MiB 1 args.rttms_dir,
253 812.148 MiB 0.000 MiB 1 f"epochs{args.epochs}",
254 812.148 MiB 0.000 MiB 1 f"timeshuffle{args.time_shuffle}",
255 812.148 MiB 0.000 MiB 1 (f"spk_qty{args.estimate_spk_qty}_"
256 f"spk_qty_thr{args.estimate_spk_qty_thr}"),
257 812.148 MiB 0.000 MiB 1 f"detection_thr{args.threshold}",
258 812.148 MiB 0.000 MiB 1 f"median{args.median_window_length}",
259 812.148 MiB 0.000 MiB 1 "rttms"
260 )
261 812.297 MiB 0.148 MiB 1 Path(out_dir).mkdir(parents=True, exist_ok=True)
262
263 34718.992 MiB -32000.816 MiB 4 for i, batch in enumerate(infer_loader):
264 34689.719 MiB 48.301 MiB 3 input = torch.stack(batch['xs']).to(args.device)
265 34689.719 MiB 0.000 MiB 3 name = batch['names'][0]
266 34689.719 MiB 0.000 MiB 3 with torch.no_grad():
267 34749.977 MiB 1166.988 MiB 3 y_pred = model.estimate_sequential(input, args)[0]
268 34749.977 MiB 13027.621 MiB 6 post_y = postprocess_output(
269 34749.977 MiB -32767.594 MiB 3 y_pred, args.subsampling,
270 34749.977 MiB -32767.594 MiB 3 args.threshold, args.median_window_length)
271 34749.977 MiB -32767.594 MiB 3 rttm_filename = join(out_dir, f"{name}.rttm")
272 34749.977 MiB -32767.594 MiB 3 with open(rttm_filename, 'w') as rttm_file:
273 34749.977 MiB 45795.215 MiB 3 hard_labels_to_rttm(post_y, name, rttm_file)
```
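For reference, this is roughly how I collected the numbers above (a minimal sketch; I am assuming the standard memory_profiler package, whose @profile decorator is the one that appears in the trace):

```python
# Sketch: per-line RAM profiling with memory_profiler (pip install memory-profiler).
# Only the decorator is added to infer.py; the body of main() stays unchanged.
from memory_profiler import profile


@profile
def main():
    ...  # existing body of infer.py's main()


# Run under the profiler so every line of main() gets a memory reading:
#   python -m memory_profiler infer.py <usual arguments>
```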
So it looks like the largest consumption comes from infer_loader.
Any ideas on how we could improve this?
I am running on pretty long audio files, around 1 hour each.
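One direction I have been considering, in case it is useful for the discussion (a rough, untested sketch with hypothetical names such as LazyInferDataset and per-recording .npy paths; I am assuming get_infer_dataloader currently reads all recordings' features into RAM up front, which the ~32 GB jump at the loop suggests): keep only file paths in the dataset and read a single recording inside __getitem__, so that one ~1-hour recording at a time is in memory.

```python
# Rough sketch (hypothetical class and paths, not the repo's actual loader): load each
# recording's features lazily inside __getitem__ so that only the recording currently
# being inferred is resident in RAM, instead of the whole set being cached up front.
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class LazyInferDataset(Dataset):
    def __init__(self, feature_paths):
        # Only the list of paths is kept in memory; nothing is read yet.
        self.feature_paths = feature_paths  # assumption: one .npy feature matrix per recording

    def __len__(self):
        return len(self.feature_paths)

    def __getitem__(self, idx):
        path = self.feature_paths[idx]
        # mmap_mode='r' opens the file without reading it; np.array() then copies
        # just this one recording into RAM when the loop actually asks for it.
        feats = np.array(np.load(path, mmap_mode='r'), dtype=np.float32)
        return {'xs': [torch.from_numpy(feats)], 'names': [Path(path).stem]}


# collate_fn passes the dict through unchanged, matching the loop's
# torch.stack(batch['xs']) / batch['names'][0] accesses; num_workers=0 keeps a
# single copy of the data in the parent process.
loader = DataLoader(LazyInferDataset(["rec1.npy"]), batch_size=1,
                    num_workers=0, collate_fn=lambda batch: batch[0])
```

With something like this, peak usage should track the largest single recording plus the model rather than the whole file list, though I have not measured it yet.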