
Extensive memory consumption in infer.py #2

@Jamiroquai88

Description


Hey, I have been using the codebase for quite some time and I noticed excessive memory consumption by the infer.py script. In my case it consumes around 120 GB of RAM, which is a big deal: my machine has 2 TB of RAM, but I need to run several instances in parallel.

I ran a simple per-line RAM profiler, and this is the output:


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   207  411.480 MiB  411.480 MiB           1   @profile
   208                                         def main():
   209  411.594 MiB    0.113 MiB           1       args = parse_arguments()
   210
   211                                             # For reproducibility
   212  411.602 MiB    0.008 MiB           1       torch.manual_seed(args.seed)
   213  411.602 MiB    0.000 MiB           1       torch.cuda.manual_seed(args.seed)
   214  411.602 MiB    0.000 MiB           1       torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
   215  411.602 MiB    0.000 MiB           1       np.random.seed(args.seed)  # Numpy module.
   216  411.602 MiB    0.000 MiB           1       random.seed(args.seed)  # Python random module.
   217  411.602 MiB    0.000 MiB           1       torch.manual_seed(args.seed)
   218  411.602 MiB    0.000 MiB           1       torch.backends.cudnn.benchmark = False
   219  411.602 MiB    0.000 MiB           1       torch.backends.cudnn.deterministic = True
   220  411.602 MiB    0.000 MiB           1       os.environ['PYTHONHASHSEED'] = str(args.seed)
   221
   222  411.602 MiB    0.000 MiB           1       logging.info(args)
   223
   224  483.664 MiB  483.664 MiB           1       infer_loader = get_infer_dataloader(args)
   225
   226  483.664 MiB    0.000 MiB           1       if args.gpu >= 1:
   227                                                 gpuid = use_single_gpu(args.gpu)
   228                                                 logging.info('GPU device {} is used'.format(gpuid))
   229                                                 args.device = torch.device("cuda")
   230                                             else:
   231  483.664 MiB    0.000 MiB           1           gpuid = -1
   232  483.664 MiB    0.000 MiB           1           args.device = torch.device("cpu")
   233
   234  483.664 MiB    0.000 MiB           1       assert args.estimate_spk_qty_thr != -1 or \
   235                                                 args.estimate_spk_qty != -1, \
   236                                                 ("Either 'estimate_spk_qty_thr' or 'estimate_spk_qty' "
   237                                                  "arguments have to be defined.")
   238  483.664 MiB    0.000 MiB           1       if args.estimate_spk_qty != -1:
   239  483.664 MiB    0.000 MiB           3           out_dir = join(args.rttms_dir, f"spkqty{args.estimate_spk_qty}_\
   240  483.664 MiB    0.000 MiB           2               thr{args.threshold}_median{args.median_window_length}")
   241                                             elif args.estimate_spk_qty_thr != -1:
   242                                                 out_dir = join(args.rttms_dir, f"spkqtythr{args.estimate_spk_qty_thr}_\
   243                                                     thr{args.threshold}_median{args.median_window_length}")
   244
   245  510.508 MiB   26.844 MiB           1       model = get_model(args)
   246
   247  812.148 MiB  301.641 MiB           2       model = average_checkpoints(
   248  510.508 MiB    0.000 MiB           1           args.device, model, args.models_path, args.epochs)
   249  812.148 MiB    0.000 MiB           1       model.eval()
   250
   251  812.148 MiB    0.000 MiB           2       out_dir = join(
   252  812.148 MiB    0.000 MiB           1           args.rttms_dir,
   253  812.148 MiB    0.000 MiB           1           f"epochs{args.epochs}",
   254  812.148 MiB    0.000 MiB           1           f"timeshuffle{args.time_shuffle}",
   255  812.148 MiB    0.000 MiB           1           (f"spk_qty{args.estimate_spk_qty}_"
   256                                                     f"spk_qty_thr{args.estimate_spk_qty_thr}"),
   257  812.148 MiB    0.000 MiB           1           f"detection_thr{args.threshold}",
   258  812.148 MiB    0.000 MiB           1           f"median{args.median_window_length}",
   259  812.148 MiB    0.000 MiB           1           "rttms"
   260                                             )
   261  812.297 MiB    0.148 MiB           1       Path(out_dir).mkdir(parents=True, exist_ok=True)
   262
   263 34718.992 MiB -32000.816 MiB           4       for i, batch in enumerate(infer_loader):
   264 34689.719 MiB   48.301 MiB           3           input = torch.stack(batch['xs']).to(args.device)
   265 34689.719 MiB    0.000 MiB           3           name = batch['names'][0]
   266 34689.719 MiB    0.000 MiB           3           with torch.no_grad():
   267 34749.977 MiB 1166.988 MiB           3               y_pred = model.estimate_sequential(input, args)[0]
   268 34749.977 MiB 13027.621 MiB           6           post_y = postprocess_output(
   269 34749.977 MiB -32767.594 MiB           3               y_pred, args.subsampling,
   270 34749.977 MiB -32767.594 MiB           3               args.threshold, args.median_window_length)
   271 34749.977 MiB -32767.594 MiB           3           rttm_filename = join(out_dir, f"{name}.rttm")
   272 34749.977 MiB -32767.594 MiB           3           with open(rttm_filename, 'w') as rttm_file:
   273 34749.977 MiB 45795.215 MiB           3               hard_labels_to_rttm(post_y, name, rttm_file)
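(For reference, the report above comes from a line-by-line memory profiler; a minimal sketch of the setup, assuming the memory_profiler package, whose @profile decorator and column layout match the output:)

# Minimal sketch of the profiling setup, assuming the memory_profiler package
# (the @profile decorator and the report layout above match its output).
from memory_profiler import profile

@profile
def main():
    ...  # body of infer.py's main(), as shown in the report above

if __name__ == '__main__':
    main()

# Alternatively, keep the bare @profile decorator (no import) and run:
#   python -m memory_profiler infer.py <usual arguments>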

So it looks like the bulk of the memory consumption is coming from infer_loader.
Any ideas on how we could improve this?

I am running on pretty long audios, ~1 hour each.
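
If it helps to reproduce, here is a rough sketch for isolating the dataloader footprint (psutil is an assumption; parse_arguments and get_infer_dataloader are the infer.py functions visible in the report above, assumed importable):

import os
import psutil  # assumption: psutil is available for RSS measurements

from infer import parse_arguments, get_infer_dataloader  # assumption: importable from infer.py

def rss_gib():
    # Resident set size of the current process, in GiB.
    return psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3

if __name__ == '__main__':
    args = parse_arguments()
    print(f"before loader:  {rss_gib():.2f} GiB")
    infer_loader = get_infer_dataloader(args)
    print(f"loader built:   {rss_gib():.2f} GiB")
    for i, batch in enumerate(infer_loader):
        # RSS after each batch shows whether whole recordings/features stay resident.
        print(f"after batch {i}: {rss_gib():.2f} GiB")
        if i == 2:
            break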
