8 changes: 6 additions & 2 deletions distml/operator/base_operator.py
@@ -1,6 +1,7 @@
"""Abstract class for framework-specific training operators."""
from abc import ABCMeta
from abc import abstractmethod
from typing import List, Optional


class TrainingOperator(metaclass=ABCMeta):
@@ -90,7 +91,7 @@ def load_custom_states(self, states, *args, **kwargs):
pass

@abstractmethod
def save_states(self, checkpoint):
def save_states(self, checkpoint: str):
"""Save the states to a file path.

This function shall be implemented in framework-specific operator
@@ -104,7 +105,10 @@ def get_states(self):
raise NotImplementedError()

@abstractmethod
def load_states(self, checkpoint):
def load_states(self,
states=None,
checkpoint: Optional[str] = None,
keys: Optional[List[str]] = None):
"""Load the states from a file path.

This function shall be implemented in framework-specific operators
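To make the reworked abstract signatures concrete, a small usage sketch (hypothetical caller code, not part of this diff; MyJAXOperator is an illustrative subclass name): `load_states` can now be fed either an in-memory states dict or a pickled checkpoint path.

    operator = MyJAXOperator(operator_config={"lr": 0.01})

    # Restore from an in-memory snapshot taken earlier...
    snapshot = operator.get_states()
    operator.load_states(states=snapshot)

    # ...or from a checkpoint file written by save_states.
    operator.save_states("ckpt.pkl")
    operator.load_states(checkpoint="ckpt.pkl")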
200 changes: 170 additions & 30 deletions distml/operator/jax_operator.py
@@ -1,5 +1,12 @@
import os
import pickle
import warnings

from typing import Any, Mapping, Optional, List, Dict

import numpy as np
import cupy as cp

import jax
from jax import value_and_grad
import jax.numpy as jnp
@@ -13,7 +20,7 @@


class JAXTrainingOperator(TrainingOperator):
def __init__(self, operator_config):
def __init__(self, *, operator_config: Optional[Mapping[str, Any]]):
super(JAXTrainingOperator, self).__init__(operator_config)
# Should be set by users in the `register` function.
# model methods
@@ -26,11 +33,14 @@ def __init__(self, operator_config):
self.get_params = None

self.criterion = None
self.lr_scheduler = None

# Data loaders for training and validation, registered by users.
self._train_loader = None
self._validation_loader = None

self._custom_states = None

self.setup(operator_config)

if hasattr(operator_config, "jit_mode"):
@@ -63,7 +73,7 @@ def setup(self, *args, **kwargs):
raise NotImplementedError("Please override this function to register "
"your model, optimizer, and criterion.")

def register(self, *, model, optimizer, criterion, jit_mode=False):
def register(self, *, model, optimizer, criterion, jit_mode: bool = False):
"""Register a few critical information about the model to operator.

Args:
@@ -93,7 +103,7 @@ def register(self, *, model, optimizer, criterion, jit_mode=False):
"'opt_init', 'opt_update' and 'get_params'."
"Got: {} {}".format(type(optimizer), len(optimizer)))

if not hasattr(criterion, "__call__"):
if not callable(criterion):
raise RuntimeError(
"The `criterion` must be callable function that "
"feed logits and target, return the loss value. "
@@ -113,12 +123,12 @@ def _register_model(self, model):
"`opt_states` return from optimizer `opt_init`. "
"Got: {}".format(type(model[0])))

if not hasattr(model[1], "__call__"):
if not callable(model[1]):
raise RuntimeError("The second elemente of `model` must be the "
"`init_fun` return from model. "
"Got: {}".format(type(model[1])))

if not hasattr(model[2], "__call__"):
if not callable(model[2]):
raise RuntimeError("The third elemente of `model` must be the "
"`predict_fun` return from model. "
"Got: {}".format(type(model[2])))
@@ -129,18 +139,18 @@

def _register_optimizer(self, optimizer):
"""register optimizer components."""
if not hasattr(optimizer[0], "__call__"):
if not callable(optimizer[0]):
raise RuntimeError("The fist elemente of `optimizer` must be the "
"`opt_init` return from optimizer. "
"Got: {}".format(type(optimizer[1])))

if not hasattr(optimizer[1], "__call__"):
if not callable(optimizer[1]):
raise RuntimeError(
"The second elemente of `optimizer` must be the "
"`opt_update` return from optimizer. "
"Got: {}".format(type(optimizer[1])))

if not hasattr(optimizer[2], "__call__"):
if not callable(optimizer[2]):
raise RuntimeError("The third elemente of `optimizer` must be the "
"`get_params` return from optimizer. "
"Got: {}".format(type(optimizer[2])))
@@ -264,15 +274,15 @@ def validate_batch(self, batch):
targets_class = jnp.argmax(targets, axis=1)

acc = jnp.mean(prediction_class == targets_class)
samples_num = targets.shape[0]
num_sample = targets.shape[0]

return {
"val_loss": loss.item(),
"val_accuracy": acc.item(),
"samples_num": samples_num
"num_sample": num_sample
}

def get_parameters(self, cpu):
def get_parameters(self, cpu: bool) -> List:
"""get the flatten parameters."""
params = self.get_params(self.opt_state)
flatten_params, tree = tree_flatten(params)
@@ -281,9 +291,11 @@ def get_parameters(self, cpu):

if cpu:
flatten_params = list(map(np.asarray, flatten_params))
else:
flatten_params = list(map(jnp.asarray, flatten_params))
return flatten_params

def get_named_parameters(self, cpu):
def get_named_parameters(self, cpu: bool) -> Dict:
"""Get the named parameters.

In jax, we need to construct a dict to contain the parameters.
@@ -296,6 +308,7 @@ def get_named_parameters(self, cpu):
}
else:
dict_params = {f"{idx}": p for idx, p in enumerate(params)}

return dict_params
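Both accessors are views over the same flattened parameter leaves; a small illustration (assuming the hypothetical `operator` instance from the earlier sketch and no preset keys installed via set_parameters):

    flat = operator.get_parameters(cpu=True)         # [np.ndarray, ...] in tree order
    named = operator.get_named_parameters(cpu=True)  # {"0": leaf, "1": leaf, ...}
    assert list(named.keys()) == [str(i) for i in range(len(flat))]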

# TODO(HUI): used in load states or load parameters
@@ -309,6 +322,9 @@ def set_parameters(self, new_params):
"""
assert isinstance(new_params, dict)

# Make sure all params live on the device. This should be controlled by `use_gpu`.
new_params = {k: jax.device_put(v) for k, v in new_params.items()}

keys, new_params = unzip2(
sorted(new_params.items(), key=lambda d: int(d[0])))
self.preset_keys = keys
@@ -334,7 +350,7 @@ def update(param, state):
zip(subtrees, new_subtrees)):
if new_subtree != subtree:
msg = (
"input structur did not match the save params struture. "
"input structure did not match the save params structure. "
"input {} and output {}.")
raise TypeError(msg.format(subtree, new_subtree))

@@ -346,29 +362,153 @@ def reset_optimizer_for_params(self, params):
"Got {}".format(type(params)))

keys, params = unzip2(sorted(params.items(), key=lambda d: int(d[0])))

self.preset_keys = keys # The keys to index the params.
self.tree = tree_structure(params)
self.opt_state = self.opt_init(params)

def ones(self, shape, cpu: bool = True):
if cpu:
return np.ones(shape)
else:
return jnp.ones(shape)

def zeros(self, shape, cpu: bool = True):
if cpu:
return np.zeros(shape)
else:
return jnp.zeros(shape)

def ones_like(self, x, cpu: bool = True):
if cpu:
return np.ones_like(x)
else:
return jnp.ones_like(x)

def zeros_like(self, x, cpu: bool = True):
if cpu:
return np.zeros_like(x)
else:
return jnp.zeros_like(x)

def numel(self, v):
return np.size(v)

def asarray(self, v):
return jnp.asarray(v)
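These helpers differ only in whether the result lives on the host or on the default JAX backend; for instance (again using the hypothetical `operator` instance):

    operator.ones((2, 3), cpu=True)    # numpy.ndarray on the host
    operator.ones((2, 3), cpu=False)   # jax array on the default device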

def clean_redundancy(self):
del self._train_loader
del self._validation_loader
if self._train_loader:
del self._train_loader
self._train_loader = None
if self._validation_loader:
del self._validation_loader
self._validation_loader = None

# TODO(HUI): use pickle to serialize parameters or states and save it.
def save_parameters(self, checkpoint):
raise NotImplementedError(
"save_parameters is not support in jax operator.")
def register_custom_states(self, custom_states):
self._custom_states = custom_states

def load_parameters(self, checkpoint):
raise NotImplementedError(
"load_parameters is not support in jax operator.")
def get_custom_states(self):
return self._custom_states
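A short usage sketch for the new custom-state hooks (the dict contents are illustrative): anything registered here is folded into the result of get_states under the "custom" key and restored by load_states.

    operator.register_custom_states({"epoch": 0, "best_val_accuracy": 0.0})
    states = operator.get_states()
    assert "custom" in states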

def save_states(self, checkpoint):
raise NotImplementedError(
"save_states is not support in jax operator.")
def get_states(self) -> Dict:
"""Return the states of this training operator."""

def get_states(self):
raise NotImplementedError("get_states is not support in jax operator.")
states_flat, tree, subtrees = self.opt_state

states_unflat = map(tree_unflatten, subtrees, states_flat)

states_unflat_dict = {
str(idx): value
for idx, value in enumerate(states_unflat)
}

def load_states(self, checkpoint):
raise NotImplementedError(
"load_states is not support in jax operator.")
states = {
"opt_state": states_unflat_dict,
}

if self._custom_states:
states.update({"custom": self.get_custom_states()})

if self.lr_scheduler and hasattr(self.lr_scheduler,
"get_state_dict()"):
states.update({"lr_scheduler": self.lr_scheduler.get_state_dict()})

return states
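Under the assumptions above, the dict returned by get_states has roughly this shape (a sketch; the optional entries appear only when the corresponding objects were registered):

    states = operator.get_states()
    # {
    #     "opt_state": {"0": <per-leaf optimizer state>, "1": ..., ...},
    #     "custom": {...},        # only if register_custom_states was called
    #     "lr_scheduler": {...},  # only if a scheduler exposing get_state_dict is set
    # }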

def save_states(self, checkpoint: str):
states = self.get_states()
with open(checkpoint, "wb") as f:
pickle.dump(states, f)

def load_states(self,
states=None,
checkpoint: Optional[str] = None,
keys: Optional[List[str]] = None):
if checkpoint:
assert ".pkl" in checkpoint, \
"checkpoint should be a .pkl file. Got {}".format(checkpoint)
if not os.path.exists(checkpoint):
raise RuntimeError("Checkpoint file doesn't exists.")
with open(checkpoint, "rb") as f:
states = pickle.load(f)

if states:
new_opt_states = states.get("opt_state", None)
custom_states = states.get("custom", None)
lr_scheduler_states = states.get("lr_scheduler", None)

if not new_opt_states:
raise RuntimeError("subtrees of new params is empty.")

assert isinstance(new_opt_states, dict)

if not keys:
keys = tuple([
str(idx)
for idx in range(len(self.get_parameters(cpu=False)))
])
else:
# construct_opt_states_dict = OrderedDict()
construct_opt_states_dict = dict()
for key in keys:
construct_opt_states_dict[key] = new_opt_states[key]
new_opt_states = construct_opt_states_dict

new_keys, new_opt_states = unzip2(
sorted(new_opt_states.items(), key=lambda d: int(d[0])))

keys = tuple(keys)
new_keys = tuple(new_keys)
assert keys == new_keys, \
"checkpoint key doesn't match the model params."

states_flat, tree, subtrees = self.opt_state
states_flat_2, subtrees_2 = unzip2(
map(tree_flatten, new_opt_states))

if not subtrees_2:
raise RuntimeError("subtrees of new params is empty.")
for idx, (subtree, subtree_2) in enumerate(
zip(subtrees, subtrees_2)):
if subtree_2 != subtree:
msg = ("input structure did not match the save params "
"structure. input {} and output {}.")
raise TypeError(msg.format(subtree, subtree_2))

self.opt_state = OptimizerState(states_flat_2, tree, subtrees_2)

if custom_states:
self._custom_states = {**(self._custom_states or {}), **custom_states}

if lr_scheduler_states:
if hasattr(self.lr_scheduler, "set_states_dict"):
self.lr_scheduler.set_states_dict(lr_scheduler_states)
else:
warnings.warn(
"lr scheduler must have `set_states_dict` method"
" to support loading lr scheduler states.")
else:
raise RuntimeError("This checkpoint is empty."
"Got checkpoint {}, states {}".format(
checkpoint, states))
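Putting the new pieces together, a round-trip sketch under the same assumptions (the file name and custom-state contents are illustrative; the lr_scheduler entries are exercised only if a scheduler exposing get_state_dict/set_states_dict has been attached):

    operator.register_custom_states({"epoch": 3})
    operator.save_states("worker0.pkl")

    # Later, in a worker holding the same model/optimizer structure:
    operator.load_states(checkpoint="worker0.pkl")

    # Optionally pass `keys` to check the checkpoint entries against an explicit
    # list of parameter indices (as produced by get_named_parameters):
    keys = list(operator.get_named_parameters(cpu=True).keys())
    operator.load_states(checkpoint="worker0.pkl", keys=keys)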