:orphan:

.. _torch_batch_processing_ug:

############################
Torch Batch Processing API
############################

.. meta::
:description: Learn how to use the Torch Batch Processing API.

In this guide, you'll learn about the :ref:`torch_batch_process_api_ref` and how to perform batch
inference (also known as offline inference).

+---------------------------------------------------------------------+
| Visit the API reference |
+=====================================================================+
| :ref:`torch_batch_process_api_ref` |
+---------------------------------------------------------------------+

.. caution::

This is an experimental API and may change at any time.

**********
Overview
**********

The Torch Batch Processing API takes in (1) a dataset and (2) a user-defined processor class and
runs distributed data processing.

With this API, you can perform the following tasks:

- shard a dataset by number of workers available
- apply user-defined logic to each batch of data
- handle synchronization between workers
- track job progress to enable preemption and resumption of the trial

This is a flexible API that can be used for many different tasks, including batch (offline)
inference.
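The sharding behavior listed above can be made concrete with a small pure-Python sketch. The round-robin assignment shown here is one common sharding strategy and is purely illustrative; it is not a claim about Determined's exact implementation:

```python
def shard_batches(num_batches, num_workers, rank):
    """Assign batch indices to one worker round-robin by rank.

    Illustrative only; the real API shards the dataset internally.
    """
    return [i for i in range(num_batches) if i % num_workers == rank]


# With 10 batches and 4 workers, worker 0 processes batches 0, 4, 8:
print(shard_batches(10, 4, 0))  # prints [0, 4, 8]
```

Each worker applies the user-defined processing logic only to its own slice, so the work scales with the number of workers available.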

*******
Usage
*******

The main arguments to ``torch_batch_process`` are the processor class and the dataset.

.. code:: python

   torch_batch_process(
       batch_processor_cls=MyProcessor,
       dataset=dataset,
   )

``TorchBatchProcessor``
=======================

Your processor class should be a subclass of
:class:`~determined.pytorch.experimental.TorchBatchProcessor`. The two methods you must implement
are ``__init__`` and ``process_batch``. The other lifecycle methods are optional.

``TorchBatchProcessorContext``
==============================

During ``__init__`` of :class:`~determined.pytorch.experimental.TorchBatchProcessor`, we pass in a
:class:`~determined.pytorch.experimental.TorchBatchProcessorContext` object, which contains useful
methods you can use within your ``TorchBatchProcessor`` subclass.
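To make the lifecycle concrete, here is a minimal pure-Python sketch of how a processor's hooks are invoked over a worker's shard of batches. ``StubContext`` and the ``run_processor`` driver loop are illustrative stand-ins for machinery that ``torch_batch_process`` provides; they are not part of the Determined API:

```python
class StubContext:
    """Illustrative stand-in for TorchBatchProcessorContext."""

    def to_device(self, batch):
        # The real context moves tensors to the worker's device;
        # this stub passes the batch through unchanged.
        return batch


class CountingProcessor:
    """Implements the two required hooks: __init__ and process_batch."""

    def __init__(self, context):
        self.context = context
        self.seen = []

    def process_batch(self, batch, batch_idx):
        batch = self.context.to_device(batch)
        self.seen.append((batch_idx, len(batch)))


def run_processor(processor_cls, batches):
    # Simplified driver loop: construct the processor once, then
    # feed it each batch of the shard along with its index.
    processor = processor_cls(StubContext())
    for idx, batch in enumerate(batches):
        processor.process_batch(batch, idx)
    return processor.seen


print(run_processor(CountingProcessor, [[1, 2], [3, 4, 5]]))
# prints [(0, 2), (1, 3)]
```

The real API adds device placement, synchronization between workers, and checkpoint hooks on top of this basic construct-then-iterate pattern.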

******************************************
How To Perform Batch (Offline) Inference
******************************************

In this section, we'll learn how to perform batch inference using the Torch Batch Processing API.
For more information about this use case or to obtain the tutorial files, visit this [placeholder
URL].

Step 1: Define an ``InferenceProcessor``
========================================

The first step is to define an ``InferenceProcessor``. You should initialize your model in the
``__init__`` function of the ``InferenceProcessor``.

.. code:: python

   """
   Define custom processor class
   """
   import pathlib

   import torch

   from determined.pytorch.experimental import TorchBatchProcessor


   class InferenceProcessor(TorchBatchProcessor):
       def __init__(self, context):
           self.context = context
           # get_model() is a user-defined function that returns your model.
           self.model = context.prepare_model_for_inference(get_model())
           self.output = []
           self.last_index = 0

       def process_batch(self, batch, batch_idx) -> None:
           model_input = batch[0]
           model_input = self.context.to_device(model_input)

           with torch.no_grad():
               pred = self.model(model_input)
               output = {"predictions": pred, "input": batch}
               self.output.append(output)

           self.last_index = batch_idx

       def on_checkpoint_start(self):
           # At each checkpoint, persist the accumulated prediction results.
           if len(self.output) == 0:
               return
           file_name = f"prediction_output_{self.last_index}"
           with self.context.upload_path() as path:
               file_path = pathlib.Path(path, file_name)
               torch.save(self.output, file_path)

           self.output = []

Step 2: Initialize the Dataset
==============================

Initialize the dataset you want to process.

.. code:: python

   """
   Initialize dataset
   """
   import os

   import filelock
   import torchvision as tv
   from torchvision import transforms

   transform = transforms.Compose(
       [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
   )
   # Use a file lock so that only one worker downloads the dataset.
   with filelock.FileLock(os.path.join("/tmp", "inference.lock")):
       inference_data = tv.datasets.CIFAR10(
           root="/data", train=False, download=True, transform=transform
       )

Step 3: Pass the InferenceProcessor Class and Dataset
=====================================================

Finally, pass the ``InferenceProcessor`` class and the dataset to ``torch_batch_process``.

.. code:: python

   """
   Pass processor class and dataset to torch_batch_process
   """
   torch_batch_process(
       InferenceProcessor,
       inference_data,
       batch_size=64,
       checkpoint_interval=10,
   )
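With ``checkpoint_interval=10``, the processor's ``on_checkpoint_start`` hook runs after every 10 batches, flushing accumulated predictions to storage. A small pure-Python sketch of that cadence (illustrative only, not the library's actual scheduler):

```python
def checkpoint_points(num_batches, checkpoint_interval):
    """Return the batch indices after which a checkpoint is taken.

    Illustrative sketch of the cadence, not Determined's scheduler.
    """
    return [i for i in range(num_batches) if (i + 1) % checkpoint_interval == 0]


# With 25 batches and checkpoint_interval=10, checkpoints follow
# batches 9 and 19 (0-indexed); any remaining output is persisted
# when the job finishes.
print(checkpoint_points(25, 10))  # prints [9, 19]
```

A smaller interval bounds how much work is lost on preemption at the cost of more frequent uploads.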
:orphan:

.. _torch_batch_process_api_ref:

#######################################
``Torch Batch Process`` API Reference
#######################################

.. meta::
:description: Familiarize yourself with the Torch Batch Process API.

+--------------------------------------------+
| User Guide |
+============================================+
| :ref:`torch_batch_processing_ug` |
+--------------------------------------------+

.. caution::

This is an experimental API and may change at any time.

The main arguments to ``torch_batch_process`` are the processor class and the dataset.

.. code:: python

   torch_batch_process(
       batch_processor_cls=MyProcessor,
       dataset=dataset,
   )

******************************************************************
``determined.pytorch.experimental.TorchBatchProcessorContext``
******************************************************************

.. autoclass:: determined.pytorch.experimental.TorchBatchProcessorContext
:members:
:member-order: bysource

***********************************************************
``determined.pytorch.experimental.TorchBatchProcessor``
***********************************************************

.. autoclass:: determined.pytorch.experimental.TorchBatchProcessor
:members:
:member-order: bysource