Sampler lacks support for DataLoader with multiple workers #103

@wiwu2390

Description

For example, when I call estimate_learning_coeff_with_summary(..., loader=DataLoader(..., num_workers=8)), I get the following traceback:

File ~/wilson/Singfluence/.venv/lib/python3.10/site-packages/devinterp/backends/default/slt/sampler.py:343, in sample(model, loader, callbacks, evaluate, sampling_method, optimizer_kwargs, num_draws, num_chains, num_burnin_steps, num_steps_bw_draws, init_loss, grad_accum_steps, cores, seed, device, verbose, optimize_over_per_model_param, gpu_idxs, batch_size, use_amp, **kwargs)
    337         return model(data)
    339 validate_callbacks(callbacks)
    341 shared_kwargs = dict(
    342     ref_model=model,
--> 343     loader=cloudpickle.dumps(loader),
    344     evaluate=cloudpickle.dumps(evaluate),
    345     num_draws=num_draws,
    346     num_burnin_steps=num_burnin_steps,
    347     num_steps_bw_draws=num_steps_bw_draws,
    348     init_loss=init_loss,
    349     grad_accum_steps=grad_accum_steps,
    350     sampling_method=sampling_method,
    351     optimizer_kwargs=optimizer_kwargs,
    352     verbose=verbose,
    353     optimize_over_per_model_param=optimize_over_per_model_param,
    354     use_amp=use_amp,
    355 )
    357 if cores > 1:
    358     # mp.spawn(
    359     #     _sample_single_chain_mp,
   (...)
    363     #     start_method="spawn",
    364     # )
    366     if gpu_idxs is not None:

File ~/wilson/Singfluence/.venv/lib/python3.10/site-packages/cloudpickle/cloudpickle.py:1537, in dumps(obj, protocol, buffer_callback)
   1535 with io.BytesIO() as file:
   1536     cp = Pickler(file, protocol=protocol, buffer_callback=buffer_callback)
-> 1537     cp.dump(obj)
   1538     return file.getvalue()

File ~/wilson/Singfluence/.venv/lib/python3.10/site-packages/cloudpickle/cloudpickle.py:1303, in Pickler.dump(self, obj)
   1301 def dump(self, obj):
   1302     try:
-> 1303         return super().dump(obj)
   1304     except RuntimeError as e:
   1305         if len(e.args) > 0 and "recursion" in e.args[0]:

File ~/wilson/Singfluence/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:737, in _BaseDataLoaderIter.__getstate__(self)
    731 def __getstate__(self):
    732     # TODO: add limited pickling support for sharing an iterator
    733     # across multiple threads for HOGWILD.
    734     # Probably the best way to do this is by moving the sample pushing
    735     # to a separate thread and then just sharing the data queue
    736     # but signalling the end is tricky without a non-blocking API
--> 737     raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

NotImplementedError: ('{} cannot be pickled', '_MultiProcessingDataLoaderIter')

For some workloads, data loading is a non-negligible bottleneck, so it would be nice to support multiple workers.
Perhaps the easiest fix is to accept the DataLoader's constructor kwargs and build the loader inside each sampling worker, instead of pickling the loader object itself.
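To illustrate the suggestion, here is a minimal, torch-free sketch of the "ship kwargs, rebuild per worker" pattern. `FakeLoader`, `build_loader`, and `loader_spec` are hypothetical names invented for this example; the real fix would apply the same idea to `torch.utils.data.DataLoader` inside devinterp's `sample()`:

```python
import pickle
import threading

class FakeLoader:
    """Stand-in for a DataLoader: once it holds live worker state
    (locks, queues, iterators), the object itself cannot be pickled."""
    def __init__(self, dataset, num_workers=0):
        self.dataset = dataset
        self.num_workers = num_workers
        self._iter_lock = threading.Lock()  # unpicklable live state

# Pickling the live loader fails, analogous to the traceback above.
loader = FakeLoader(dataset=list(range(10)), num_workers=8)
try:
    pickle.dumps(loader)
    pickled_ok = True
except TypeError:
    pickled_ok = False

# Proposed alternative: ship only the dataset and constructor kwargs,
# which are plain data and pickle cleanly...
loader_spec = {"dataset": list(range(10)), "num_workers": 8}
payload = pickle.dumps(loader_spec)

def build_loader(spec_bytes):
    # ...then reconstruct a fresh loader inside each sampling worker.
    spec = pickle.loads(spec_bytes)
    return FakeLoader(**spec)

rebuilt = build_loader(payload)
```

The same shape would work for the real API: each chain worker calls `DataLoader(dataset, **loader_kwargs)` locally, so the multiprocessing iterator is created after the fork/spawn rather than serialized across it.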
