diff --git a/examples/pytorch_nyctaxi.py b/examples/pytorch_nyctaxi.py
index b3ef39f8..f60cfc3a 100644
--- a/examples/pytorch_nyctaxi.py
+++ b/examples/pytorch_nyctaxi.py
@@ -72,7 +72,9 @@ def forward(self, x):
                            feature_columns=features, feature_types=torch.float,
                            label_column="fare_amount", label_type=torch.float,
                            batch_size=64, num_epochs=30,
-                           metrics_name = ["MeanAbsoluteError", "MeanSquaredError"])
+                           metrics_name=["MeanAbsoluteError", "MeanSquaredError"],
+                           use_ipex=False, use_bf16=False, use_amp=False,
+                           use_ccl=False, use_jit_trace=False)
 # Train the model
 estimator.fit_on_spark(train_df, test_df)
 # Get the trained model
diff --git a/python/raydp/torch/config.py b/python/raydp/torch/config.py
new file mode 100644
index 00000000..f902f919
--- /dev/null
+++ b/python/raydp/torch/config.py
@@ -0,0 +1,42 @@
+from ray.train.torch.config import _TorchBackend
+from ray.train.torch.config import TorchConfig as RayTorchConfig
+from ray.train._internal.worker_group import WorkerGroup
+from dataclasses import dataclass
+from packaging import version
+import importlib_metadata
+
+@dataclass
+class TorchConfig(RayTorchConfig):
+
+    @property
+    def backend_cls(self):
+        return EnableCCLBackend
+
+def ccl_import():
+    # pylint: disable=import-outside-toplevel
+    import oneccl_bindings_for_pytorch
+
+class EnableCCLBackend(_TorchBackend):
+
+    def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
+        for i in range(len(worker_group)):
+            worker_group.execute_single_async(i, ccl_import)
+        super().on_start(worker_group, backend_config)
+
+def check_ipex():
+    def get_major_and_minor_from_version(full_version):
+        return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
+
+    try:
+        _torch_version = importlib_metadata.version("torch")
+    except importlib_metadata.PackageNotFoundError:
+        return None, None
+
+    try:
+        _ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
+    except importlib_metadata.PackageNotFoundError:
+        return _torch_version, None
+
+    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
+    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
+    return torch_major_and_minor, ipex_major_and_minor
diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py
index fec2a506..470523d4 100644
--- a/python/raydp/torch/estimator.py
+++ b/python/raydp/torch/estimator.py
@@ -26,6 +26,7 @@
 from raydp.torch.torch_metrics import TorchMetric
 from raydp import stop_spark
 from raydp.spark import spark_dataframe_to_ray_dataset
+from raydp.torch.config import TorchConfig, check_ipex
 
 import ray
 from ray import train
@@ -89,6 +90,11 @@ def __init__(self,
                  num_processes_for_data_loader: int = 0,
                  metrics_name: Optional[List[Union[str, Callable]]] = None,
                  metrics_config: Optional[Dict[str,Dict[str, Any]]] = None,
+                 use_ipex: bool = False,
+                 use_bf16: bool = False,
+                 use_amp: bool = False,
+                 use_ccl: bool = False,
+                 use_jit_trace: bool = False,
                  **extra_config):
         """
         :param num_workers: the number of workers to do the distributed training
@@ -131,6 +137,12 @@ def __init__(self,
         :param metrics_config: the optional config for the metrics. Its format is:
                {"metric_name_1": {"param1": value1, "param2": value2}, "metric_name_2":{}},
               where param is the parameter corresponding to a concrete metric class of TorchMetrics.
+        :param use_ipex: whether to enable IPEX (Intel Extension for PyTorch) optimization
+        :param use_bf16: whether to cast model parameters to ``torch.bfloat16``
+        :param use_amp: whether to enable automatic mixed precision
+        :param use_ccl: whether to use torch_ccl as the backend to initialize the default
+               distributed process group
+        :param use_jit_trace: whether to use ``torch.jit.trace`` to accelerate the model
         :param extra_config: the extra config will be set to ray.train.torch.TorchTrainer
         """
         self._num_workers = num_workers
@@ -149,11 +161,25 @@ def __init__(self,
         self._shuffle = shuffle
         self._num_processes_for_data_loader = num_processes_for_data_loader
         self._metrics = TorchMetric(metrics_name, metrics_config)
+        self._use_ipex = use_ipex
+        self._use_bf16 = use_bf16
+        self._use_amp = use_amp
+        self._use_ccl = use_ccl
+        self._use_jit_trace = use_jit_trace
         self._extra_config = extra_config
 
         if self._num_processes_for_data_loader > 0:
             raise TypeError("multiple processes for data loader has not supported")
 
+        if self._use_ipex:
+            torch_version, ipex_version = check_ipex()
+            assert torch_version is not None, "PyTorch is not found. Please install PyTorch."
+            assert ipex_version is not None, "Intel Extension for PyTorch is not found. "\
+                "Please install Intel Extension for PyTorch."
+            assert torch_version == ipex_version, "Intel Extension for PyTorch {ipex} needs to "\
+                "work with PyTorch {ipex}.*, but PyTorch {torch} is found. Please switch to "\
+                "the matching version.".format(ipex=ipex_version, torch=torch_version)
+
         self._trainer: TorchTrainer = None
         self._check()
@@ -211,6 +237,18 @@ def train_func(config):
         # get merics
         metrics = config["metrics"]
 
+        # ipex optimize
+        use_ipex = config["use_ipex"]
+        use_bf16 = config["use_bf16"]
+        use_amp = config["use_amp"]
+        use_jit_trace = config["use_jit_trace"]
+        if use_ipex:
+            # pylint: disable=import-outside-toplevel
+            import intel_extension_for_pytorch as ipex
+            model = model.to(memory_format=torch.channels_last)
+            dtype = torch.bfloat16 if use_bf16 else None
+            model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype)
+
         # create dataset
         train_data_shard = session.get_dataset_shard("train")
         train_dataset = train_data_shard.to_torch(feature_columns=config["feature_columns"],
@@ -233,11 +271,14 @@ def train_func(config):
         loss_results = []
         for epoch in range(config["num_epochs"]):
             train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss,
-                                                               optimizer, metrics, lr_scheduler)
+                                                               optimizer, metrics, use_amp,
+                                                               use_bf16, use_jit_trace,
+                                                               lr_scheduler)
             session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss))
             if config["evaluate"]:
-                eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset,
-                                                                        model, loss, metrics)
+                eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset, model,
+                                                                        loss, metrics, use_amp,
+                                                                        use_bf16, use_jit_trace)
                 session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss))
                 loss_results.append(evaluate_loss)
         if hasattr(model, "module"):
@@ -250,13 +291,14 @@ def train_func(config):
             }))
 
     @staticmethod
-    def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
+    def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, use_jit_trace,
+                    scheduler=None):
         model.train()
         train_loss, data_size, batch_idx = 0, 0, 0
         for batch_idx, (inputs, targets) in enumerate(dataset):
             # Compute prediction error
-            outputs = model(inputs)
-            loss = criterion(outputs, targets)
+            outputs, loss = TorchEstimator.train_batch(batch_idx, model, inputs, targets, criterion,
+                                                       use_amp, use_bf16, use_jit_trace)
             train_loss += loss.item()
             metrics.update(outputs, targets)
             data_size += targets.size(0)
@@ -273,14 +315,16 @@ def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
         return train_res, train_loss
 
     @staticmethod
-    def evaluate_epoch(dataset, model, criterion, metrics):
+    def evaluate_epoch(dataset, model, criterion, metrics, use_amp, use_bf16, use_jit_trace):
         model.eval()
         test_loss, data_size, batch_idx = 0, 0, 0
         with torch.no_grad():
             for batch_idx, (inputs, targets) in enumerate(dataset):
                 # Compute prediction error
-                outputs = model(inputs)
-                test_loss += criterion(outputs, targets).item()
+                outputs, loss = TorchEstimator.train_batch(batch_idx, model, inputs, targets,
+                                                           criterion, use_amp, use_bf16,
+                                                           use_jit_trace, is_eval=True)
+                test_loss += loss.item()
                 metrics.update(outputs, targets)
                 data_size += targets.size(0)
@@ -289,6 +333,26 @@ def evaluate_epoch(dataset, model, criterion, metrics):
         metrics.reset()
         return eval_res, test_loss
 
+    @staticmethod
+    def train_batch(batch_idx, model, inputs, targets, criterion, use_amp, use_bf16, use_jit_trace,
+                    is_eval=False):
+        if use_amp and use_bf16:
+            with torch.cpu.amp.autocast():
+                if use_jit_trace and batch_idx == 0:
+                    model = torch.jit.trace(model, inputs)
+                    if is_eval:
+                        model = torch.jit.freeze(model)
+                outputs = model(inputs)
+                loss = criterion(outputs, targets)
+        else:
+            if use_jit_trace and batch_idx == 0:
+                model = torch.jit.trace(model, inputs)
+                if is_eval:
+                    model = torch.jit.freeze(model)
+            outputs = model(inputs)
+            loss = criterion(outputs, targets)
+        return outputs, loss
+
     def fit(self,
             train_ds: Dataset,
             evaluate_ds: Optional[Dataset] = None,
@@ -306,7 +370,11 @@ def fit(self,
             "num_epochs": self._num_epochs,
             "drop_last": self._drop_last,
             "evaluate": True,
-            "metrics": self._metrics
+            "metrics": self._metrics,
+            "use_ipex": self._use_ipex,
+            "use_bf16": self._use_bf16,
+            "use_amp": self._use_amp,
+            "use_jit_trace": self._use_jit_trace
         }
         scaling_config = ScalingConfig(num_workers=self._num_workers,
                                        resources_per_worker=self._resources_per_worker)
@@ -320,10 +388,15 @@ def fit(self,
             train_loop_config["evaluate"] = False
         else:
             datasets["evaluate"] = evaluate_ds
+        if self._use_ccl:
+            torch_config = TorchConfig(backend="ccl")
+        else:
+            torch_config = None
         self._trainer = TorchTrainer(TorchEstimator.train_func,
                                      train_loop_config=train_loop_config,
                                      scaling_config=scaling_config,
                                      run_config=run_config,
+                                     torch_config=torch_config,
                                      datasets=datasets)
         result = self._trainer.fit()
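
Note: a minimal usage sketch of the new knobs, patterned on the nyctaxi example above. The
leading constructor arguments (num_workers, model, optimizer, loss) are not part of this diff
and are assumed from the example file; running it requires an intel_extension_for_pytorch
install whose major.minor version matches the local torch, plus oneccl_bindings_for_pytorch
on every Ray worker when use_ccl=True:

    estimator = TorchEstimator(num_workers=2, model=model, optimizer=optimizer, loss=criterion,
                               feature_columns=features, feature_types=torch.float,
                               label_column="fare_amount", label_type=torch.float,
                               batch_size=64, num_epochs=30,
                               metrics_name=["MeanAbsoluteError", "MeanSquaredError"],
                               use_ipex=True,       # workers call ipex.optimize(model, optimizer=..., dtype=...)
                               use_bf16=True,       # ipex.optimize is handed dtype=torch.bfloat16
                               use_amp=True,        # with use_bf16, forward passes run under torch.cpu.amp.autocast()
                               use_ccl=True,        # TorchTrainer gets TorchConfig(backend="ccl")
                               use_jit_trace=True)  # model is torch.jit.trace'd on each epoch's first batch
    estimator.fit_on_spark(train_df, test_df)

check_ipex() compares only major.minor, so e.g. torch 1.13.1 is accepted alongside
intel_extension_for_pytorch 1.13.100.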