From 243f33931cbbce2e6097576cde94a86c28b97322 Mon Sep 17 00:00:00 2001 From: KepingYan Date: Wed, 8 Feb 2023 13:29:50 +0800 Subject: [PATCH 1/7] enable ccl in TorchEstimator --- examples/pytorch_nyctaxi.py | 4 +++- python/raydp/torch/estimator.py | 18 +++++++++++++++++- python/raydp/torch/torch_ccl_config.py | 22 ++++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 python/raydp/torch/torch_ccl_config.py diff --git a/examples/pytorch_nyctaxi.py b/examples/pytorch_nyctaxi.py index b3ef39f8..c589dd4a 100644 --- a/examples/pytorch_nyctaxi.py +++ b/examples/pytorch_nyctaxi.py @@ -72,7 +72,9 @@ def forward(self, x): feature_columns=features, feature_types=torch.float, label_column="fare_amount", label_type=torch.float, batch_size=64, num_epochs=30, - metrics_name = ["MeanAbsoluteError", "MeanSquaredError"]) + metrics_name=["MeanAbsoluteError", "MeanSquaredError"], + use_ipex=False, + use_ccl=False) # Train the model estimator.fit_on_spark(train_df, test_df) # Get the trained model diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index fec2a506..66193146 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -34,6 +34,7 @@ from ray.air.checkpoint import Checkpoint from ray.air import session from ray.data.dataset import Dataset +from raydp.torch.torch_ccl_config import CCLConfig class TorchEstimator(EstimatorInterface, SparkEstimatorInterface): """ @@ -89,6 +90,8 @@ def __init__(self, num_processes_for_data_loader: int = 0, metrics_name: Optional[List[Union[str, Callable]]] = None, metrics_config: Optional[Dict[str,Dict[str, Any]]] = None, + use_ipex: bool = False, + use_ccl: bool = False, **extra_config): """ :param num_workers: the number of workers to do the distributed training @@ -149,6 +152,8 @@ def __init__(self, self._shuffle = shuffle self._num_processes_for_data_loader = num_processes_for_data_loader self._metrics = TorchMetric(metrics_name, metrics_config) + self._use_ipex = use_ipex + self._use_ccl = use_ccl self._extra_config = extra_config if self._num_processes_for_data_loader > 0: @@ -210,6 +215,10 @@ def train_func(config): # get merics metrics = config["metrics"] + if config["use_ipex"]: + import intel_extension_for_pytorch as ipex + model = model.to(memory_format=torch.channels_last) + model, optimizer = ipex.optimize(model, optimizer=optimizer) # create dataset train_data_shard = session.get_dataset_shard("train") @@ -306,7 +315,8 @@ def fit(self, "num_epochs": self._num_epochs, "drop_last": self._drop_last, "evaluate": True, - "metrics": self._metrics + "metrics": self._metrics, + "use_ipex": self._use_ipex, } scaling_config = ScalingConfig(num_workers=self._num_workers, resources_per_worker=self._resources_per_worker) @@ -320,10 +330,16 @@ def fit(self, train_loop_config["evaluate"] = False else: datasets["evaluate"] = evaluate_ds + + if self._use_ccl: + torch_config = CCLConfig(backend='ccl') + else: + torch_config = None self._trainer = TorchTrainer(TorchEstimator.train_func, train_loop_config=train_loop_config, scaling_config=scaling_config, run_config=run_config, + torch_config=torch_config, datasets=datasets) result = self._trainer.fit() diff --git a/python/raydp/torch/torch_ccl_config.py b/python/raydp/torch/torch_ccl_config.py new file mode 100644 index 00000000..17804b33 --- /dev/null +++ b/python/raydp/torch/torch_ccl_config.py @@ -0,0 +1,22 @@ +from ray.train.torch.config import _TorchBackend +from ray.train.torch.config import TorchConfig +from dataclasses import dataclass +from ray.train._internal.worker_group import WorkerGroup + + +@dataclass +class CCLConfig(TorchConfig): + + @property + def backend_cls(self): + return EnableCCLBackend + +def ccl_import(): + import oneccl_bindings_for_pytorch + +class EnableCCLBackend(_TorchBackend): + + def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): + for i in range(len(worker_group)): + worker_group.execute_single_async(i, ccl_import) + super().on_start(worker_group, backend_config) \ No newline at end of file From ad69510f01a4489368b2a4dd0ab359c06a8840e9 Mon Sep 17 00:00:00 2001 From: KepingYan Date: Thu, 9 Feb 2023 16:31:55 +0800 Subject: [PATCH 2/7] add amp optimize --- examples/pytorch_nyctaxi.py | 3 +-- python/raydp/torch/estimator.py | 44 +++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/examples/pytorch_nyctaxi.py b/examples/pytorch_nyctaxi.py index c589dd4a..0e1b3a8d 100644 --- a/examples/pytorch_nyctaxi.py +++ b/examples/pytorch_nyctaxi.py @@ -73,8 +73,7 @@ def forward(self, x): label_column="fare_amount", label_type=torch.float, batch_size=64, num_epochs=30, metrics_name=["MeanAbsoluteError", "MeanSquaredError"], - use_ipex=False, - use_ccl=False) + use_ipex=False, use_bf16=False, use_amp=False, use_ccl=False) # Train the model estimator.fit_on_spark(train_df, test_df) # Get the trained model diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index 66193146..6f6237b1 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -91,6 +91,8 @@ def __init__(self, metrics_name: Optional[List[Union[str, Callable]]] = None, metrics_config: Optional[Dict[str,Dict[str, Any]]] = None, use_ipex: bool = False, + use_bf16: bool = False, + use_amp: bool = False, use_ccl: bool = False, **extra_config): """ @@ -153,6 +155,8 @@ def __init__(self, self._num_processes_for_data_loader = num_processes_for_data_loader self._metrics = TorchMetric(metrics_name, metrics_config) self._use_ipex = use_ipex + self._use_bf16 = use_bf16 + self._use_amp = use_amp self._use_ccl = use_ccl self._extra_config = extra_config @@ -215,10 +219,16 @@ def train_func(config): # get merics metrics = config["metrics"] - if config["use_ipex"]: + + # ipex optimize + use_ipex = config["use_ipex"] + use_bf16 = config["use_bf16"] + use_amp = config["use_amp"] + if use_ipex: import intel_extension_for_pytorch as ipex model = model.to(memory_format=torch.channels_last) - model, optimizer = ipex.optimize(model, optimizer=optimizer) + dtype = torch.bfloat16 if use_bf16 else None + model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype) # create dataset train_data_shard = session.get_dataset_shard("train") @@ -242,11 +252,13 @@ def train_func(config): loss_results = [] for epoch in range(config["num_epochs"]): train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss, - optimizer, metrics, lr_scheduler) + optimizer, metrics, use_amp, + use_bf16, lr_scheduler) session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss)) if config["evaluate"]: eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset, - model, loss, metrics) + model, loss, metrics, + use_amp, use_bf16) session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss)) loss_results.append(evaluate_loss) if hasattr(model, "module"): @@ -259,13 +271,18 @@ def train_func(config): })) @staticmethod - def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None): + def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, scheduler=None): model.train() train_loss, data_size, batch_idx = 0, 0, 0 for batch_idx, (inputs, targets) in enumerate(dataset): # Compute prediction error - outputs = model(inputs) - loss = criterion(outputs, targets) + if use_amp and use_bf16: + with torch.cpu.amp.autocast(): + outputs = model(inputs) + loss = criterion(outputs, targets) + else: + outputs = model(inputs) + loss = criterion(outputs, targets) train_loss += loss.item() metrics.update(outputs, targets) data_size += targets.size(0) @@ -282,14 +299,19 @@ def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None): return train_res, train_loss @staticmethod - def evaluate_epoch(dataset, model, criterion, metrics): + def evaluate_epoch(dataset, model, criterion, metrics, use_amp, use_bf16): model.eval() test_loss, data_size, batch_idx = 0, 0, 0 with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(dataset): # Compute prediction error - outputs = model(inputs) - test_loss += criterion(outputs, targets).item() + if use_amp and use_bf16: + with torch.cpu.amp.autocast(): + outputs = model(inputs) + test_loss += criterion(outputs, targets) + else: + outputs = model(inputs) + test_loss += criterion(outputs, targets).item() metrics.update(outputs, targets) data_size += targets.size(0) @@ -317,6 +339,8 @@ def fit(self, "evaluate": True, "metrics": self._metrics, "use_ipex": self._use_ipex, + "use_bf16": self._use_bf16, + "use_amp": self._use_amp } scaling_config = ScalingConfig(num_workers=self._num_workers, resources_per_worker=self._resources_per_worker) From 8d8ec719284e3550da4defd279269edf8e198153 Mon Sep 17 00:00:00 2001 From: KepingYan Date: Thu, 9 Feb 2023 18:43:03 +0800 Subject: [PATCH 3/7] fix lint --- python/raydp/torch/estimator.py | 9 +++++---- python/raydp/torch/torch_ccl_config.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index 6f6237b1..9ac81520 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -26,6 +26,7 @@ from raydp.torch.torch_metrics import TorchMetric from raydp import stop_spark from raydp.spark import spark_dataframe_to_ray_dataset +from raydp.torch.torch_ccl_config import CCLConfig import ray from ray import train @@ -34,7 +35,6 @@ from ray.air.checkpoint import Checkpoint from ray.air import session from ray.data.dataset import Dataset -from raydp.torch.torch_ccl_config import CCLConfig class TorchEstimator(EstimatorInterface, SparkEstimatorInterface): """ @@ -225,6 +225,7 @@ def train_func(config): use_bf16 = config["use_bf16"] use_amp = config["use_amp"] if use_ipex: + # pylint: disable=import-outside-toplevel import intel_extension_for_pytorch as ipex model = model.to(memory_format=torch.channels_last) dtype = torch.bfloat16 if use_bf16 else None @@ -271,7 +272,8 @@ def train_func(config): })) @staticmethod - def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, scheduler=None): + def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, + scheduler=None): model.train() train_loss, data_size, batch_idx = 0, 0, 0 for batch_idx, (inputs, targets) in enumerate(dataset): @@ -354,9 +356,8 @@ def fit(self, train_loop_config["evaluate"] = False else: datasets["evaluate"] = evaluate_ds - if self._use_ccl: - torch_config = CCLConfig(backend='ccl') + torch_config = CCLConfig(backend="ccl") else: torch_config = None self._trainer = TorchTrainer(TorchEstimator.train_func, diff --git a/python/raydp/torch/torch_ccl_config.py b/python/raydp/torch/torch_ccl_config.py index 17804b33..9ac1d598 100644 --- a/python/raydp/torch/torch_ccl_config.py +++ b/python/raydp/torch/torch_ccl_config.py @@ -1,7 +1,7 @@ from ray.train.torch.config import _TorchBackend from ray.train.torch.config import TorchConfig -from dataclasses import dataclass from ray.train._internal.worker_group import WorkerGroup +from dataclasses import dataclass @dataclass @@ -12,6 +12,7 @@ def backend_cls(self): return EnableCCLBackend def ccl_import(): + # pylint: disable=import-outside-toplevel import oneccl_bindings_for_pytorch class EnableCCLBackend(_TorchBackend): @@ -19,4 +20,4 @@ class EnableCCLBackend(_TorchBackend): def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): for i in range(len(worker_group)): worker_group.execute_single_async(i, ccl_import) - super().on_start(worker_group, backend_config) \ No newline at end of file + super().on_start(worker_group, backend_config) From 8469f3347c818df7289e9ee65a0010c198531d24 Mon Sep 17 00:00:00 2001 From: KepingYan Date: Mon, 13 Feb 2023 16:01:51 +0800 Subject: [PATCH 4/7] add opt jit.trace --- examples/pytorch_nyctaxi.py | 3 +- python/raydp/torch/estimator.py | 65 ++++++++++++++++++++++----------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/examples/pytorch_nyctaxi.py b/examples/pytorch_nyctaxi.py index 0e1b3a8d..df516cc4 100644 --- a/examples/pytorch_nyctaxi.py +++ b/examples/pytorch_nyctaxi.py @@ -73,7 +73,8 @@ def forward(self, x): label_column="fare_amount", label_type=torch.float, batch_size=64, num_epochs=30, metrics_name=["MeanAbsoluteError", "MeanSquaredError"], - use_ipex=False, use_bf16=False, use_amp=False, use_ccl=False) + use_ipex=True, use_bf16=False, use_amp=False, + use_ccl=True, use_jit_trace=True) # Train the model estimator.fit_on_spark(train_df, test_df) # Get the trained model diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index 9ac81520..63b1b0ff 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -94,6 +94,7 @@ def __init__(self, use_bf16: bool = False, use_amp: bool = False, use_ccl: bool = False, + use_jit_trace: bool = False, **extra_config): """ :param num_workers: the number of workers to do the distributed training @@ -136,6 +137,12 @@ def __init__(self, :param metrics_config: the optional config for the metrics. Its format is: {"metric_name_1": {"param1": value1, "param2": value2}, "metric_name_2":{}}, where param is the parameter corresponding to a concrete metric class of TorchMetrics. + :param use_ipex: whether to enable ipex optimization + :param use_bf16: whether to cast model parameters to ``torch.bfloat16`` + :param use_amp: whether to enable auto mixed precision + :param use_ccl: whether to use torch_ccl as the backend to initialize default distributed + process group + :param use_jit_trace: whether to use jit.trace to accelerate the model :param extra_config: the extra config will be set to ray.train.torch.TorchTrainer """ self._num_workers = num_workers @@ -158,6 +165,7 @@ def __init__(self, self._use_bf16 = use_bf16 self._use_amp = use_amp self._use_ccl = use_ccl + self._use_jit_trace = use_jit_trace self._extra_config = extra_config if self._num_processes_for_data_loader > 0: @@ -224,6 +232,7 @@ def train_func(config): use_ipex = config["use_ipex"] use_bf16 = config["use_bf16"] use_amp = config["use_amp"] + use_jit_trace = config["use_jit_trace"] if use_ipex: # pylint: disable=import-outside-toplevel import intel_extension_for_pytorch as ipex @@ -254,12 +263,13 @@ def train_func(config): for epoch in range(config["num_epochs"]): train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss, optimizer, metrics, use_amp, - use_bf16, lr_scheduler) + use_bf16, use_jit_trace, + lr_scheduler) session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss)) if config["evaluate"]: - eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset, - model, loss, metrics, - use_amp, use_bf16) + eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset, model, + loss, metrics, use_amp, + use_bf16, use_jit_trace) session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss)) loss_results.append(evaluate_loss) if hasattr(model, "module"): @@ -272,19 +282,14 @@ def train_func(config): })) @staticmethod - def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, + def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, use_jit_trace, scheduler=None): model.train() train_loss, data_size, batch_idx = 0, 0, 0 for batch_idx, (inputs, targets) in enumerate(dataset): # Compute prediction error - if use_amp and use_bf16: - with torch.cpu.amp.autocast(): - outputs = model(inputs) - loss = criterion(outputs, targets) - else: - outputs = model(inputs) - loss = criterion(outputs, targets) + outputs, loss = TorchEstimator.train_batch(batch_idx, model, inputs, targets, criterion, + use_amp, use_bf16, use_jit_trace) train_loss += loss.item() metrics.update(outputs, targets) data_size += targets.size(0) @@ -301,19 +306,16 @@ def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16 return train_res, train_loss @staticmethod - def evaluate_epoch(dataset, model, criterion, metrics, use_amp, use_bf16): + def evaluate_epoch(dataset, model, criterion, metrics, use_amp, use_bf16, use_jit_trace): model.eval() test_loss, data_size, batch_idx = 0, 0, 0 with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(dataset): # Compute prediction error - if use_amp and use_bf16: - with torch.cpu.amp.autocast(): - outputs = model(inputs) - test_loss += criterion(outputs, targets) - else: - outputs = model(inputs) - test_loss += criterion(outputs, targets).item() + outputs, loss = TorchEstimator.train_batch(batch_idx, model, inputs, targets, + criterion, use_amp, use_bf16, + use_jit_trace, is_eval=True) + test_loss += loss.item() metrics.update(outputs, targets) data_size += targets.size(0) @@ -322,6 +324,26 @@ def evaluate_epoch(dataset, model, criterion, metrics, use_amp, use_bf16): metrics.reset() return eval_res, test_loss + @staticmethod + def train_batch(batch_idx, model, inputs, targets, criterion, use_amp, use_bf16, use_jit_trace, + is_eval=False): + if use_amp and use_bf16: + with torch.cpu.amp.autocast(): + if use_jit_trace and batch_idx==0: + model = torch.jit.trace(model, inputs) + if is_eval: + model = torch.jit.freeze(model) + outputs = model(inputs) + loss = criterion(outputs, targets) + else: + if use_jit_trace and batch_idx==0: + model = torch.jit.trace(model, inputs) + if is_eval: + model = torch.jit.freeze(model) + outputs = model(inputs) + loss = criterion(outputs, targets) + return outputs, loss + def fit(self, train_ds: Dataset, evaluate_ds: Optional[Dataset] = None, @@ -342,7 +364,8 @@ def fit(self, "metrics": self._metrics, "use_ipex": self._use_ipex, "use_bf16": self._use_bf16, - "use_amp": self._use_amp + "use_amp": self._use_amp, + "use_jit_trace": self._use_jit_trace } scaling_config = ScalingConfig(num_workers=self._num_workers, resources_per_worker=self._resources_per_worker) From 0eaed78faff97b73d9f14c58d779a272bb4a704b Mon Sep 17 00:00:00 2001 From: KepingYan Date: Mon, 13 Feb 2023 16:38:56 +0800 Subject: [PATCH 5/7] fix --- examples/pytorch_nyctaxi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch_nyctaxi.py b/examples/pytorch_nyctaxi.py index df516cc4..f60cfc3a 100644 --- a/examples/pytorch_nyctaxi.py +++ b/examples/pytorch_nyctaxi.py @@ -73,8 +73,8 @@ def forward(self, x): label_column="fare_amount", label_type=torch.float, batch_size=64, num_epochs=30, metrics_name=["MeanAbsoluteError", "MeanSquaredError"], - use_ipex=True, use_bf16=False, use_amp=False, - use_ccl=True, use_jit_trace=True) + use_ipex=False, use_bf16=False, use_amp=False, + use_ccl=False, use_jit_trace=False) # Train the model estimator.fit_on_spark(train_df, test_df) # Get the trained model From 21b1e82935547a8c217b97cf7002d20e7b5c3307 Mon Sep 17 00:00:00 2001 From: KepingYan Date: Fri, 24 Feb 2023 17:35:41 +0800 Subject: [PATCH 6/7] rename torchconfig & add ipex package check --- python/raydp/torch/config.py | 42 ++++++++++++++++++++++++++ python/raydp/torch/estimator.py | 13 ++++++-- python/raydp/torch/torch_ccl_config.py | 23 -------------- 3 files changed, 53 insertions(+), 25 deletions(-) create mode 100644 python/raydp/torch/config.py delete mode 100644 python/raydp/torch/torch_ccl_config.py diff --git a/python/raydp/torch/config.py b/python/raydp/torch/config.py new file mode 100644 index 00000000..ca864ce2 --- /dev/null +++ b/python/raydp/torch/config.py @@ -0,0 +1,42 @@ +from ray.train.torch.config import _TorchBackend +from ray.train.torch.config import TorchConfig as RayTorchConfig +from ray.train._internal.worker_group import WorkerGroup +from dataclasses import dataclass +from packaging import version +import importlib_metadata + +@dataclass +class TorchConfig(RayTorchConfig): + + @property + def backend_cls(self): + return EnableCCLBackend + +def ccl_import(): + # pylint: disable=import-outside-toplevel + import oneccl_bindings_for_pytorch + +class EnableCCLBackend(_TorchBackend): + + def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): + for i in range(len(worker_group)): + worker_group.execute_single_async(i, ccl_import) + super().on_start(worker_group, backend_config) + +def check_ipex(): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + try: + _torch_version = importlib_metadata.version("torch") + except importlib_metadata.PackageNotFoundError: + return None, None + + try: + _ipex_version = importlib_metadata.version("intel_extension_for_pytorch") + except importlib_metadata.PackageNotFoundError: + return _torch_version, None + + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + return torch_major_and_minor , ipex_major_and_minor \ No newline at end of file diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index 63b1b0ff..470523d4 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -26,7 +26,7 @@ from raydp.torch.torch_metrics import TorchMetric from raydp import stop_spark from raydp.spark import spark_dataframe_to_ray_dataset -from raydp.torch.torch_ccl_config import CCLConfig +from raydp.torch.config import TorchConfig, check_ipex import ray from ray import train @@ -171,6 +171,15 @@ def __init__(self, if self._num_processes_for_data_loader > 0: raise TypeError("multiple processes for data loader has not supported") + if self._use_ipex: + torch_version, ipex_version = check_ipex() + assert torch_version is not None, "Pytorch is not found. Please install Pytorch." + assert ipex_version is not None, "Intel Extension for PyTorch is not found. "\ + "Please install Intel Extension for PyTorch." + assert torch_version==ipex_version, "Intel Extension for PyTorch {ipex} needs to "\ + "work with PyTorch {ipex}.*, but PyTorch {torch} is found. Please switch to "\ + "the matching version.".format(ipex=ipex_version, torch=torch_version) + self._trainer: TorchTrainer = None self._check() @@ -380,7 +389,7 @@ def fit(self, else: datasets["evaluate"] = evaluate_ds if self._use_ccl: - torch_config = CCLConfig(backend="ccl") + torch_config = TorchConfig(backend="ccl") else: torch_config = None self._trainer = TorchTrainer(TorchEstimator.train_func, diff --git a/python/raydp/torch/torch_ccl_config.py b/python/raydp/torch/torch_ccl_config.py deleted file mode 100644 index 9ac1d598..00000000 --- a/python/raydp/torch/torch_ccl_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from ray.train.torch.config import _TorchBackend -from ray.train.torch.config import TorchConfig -from ray.train._internal.worker_group import WorkerGroup -from dataclasses import dataclass - - -@dataclass -class CCLConfig(TorchConfig): - - @property - def backend_cls(self): - return EnableCCLBackend - -def ccl_import(): - # pylint: disable=import-outside-toplevel - import oneccl_bindings_for_pytorch - -class EnableCCLBackend(_TorchBackend): - - def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): - for i in range(len(worker_group)): - worker_group.execute_single_async(i, ccl_import) - super().on_start(worker_group, backend_config) From b929a3308bd5d1dc84047ce2a0c4464e4ae515d2 Mon Sep 17 00:00:00 2001 From: KepingYan Date: Fri, 24 Feb 2023 17:40:41 +0800 Subject: [PATCH 7/7] fix lint --- python/raydp/torch/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/raydp/torch/config.py b/python/raydp/torch/config.py index ca864ce2..f902f919 100644 --- a/python/raydp/torch/config.py +++ b/python/raydp/torch/config.py @@ -36,7 +36,7 @@ def get_major_and_minor_from_version(full_version): _ipex_version = importlib_metadata.version("intel_extension_for_pytorch") except importlib_metadata.PackageNotFoundError: return _torch_version, None - + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) - return torch_major_and_minor , ipex_major_and_minor \ No newline at end of file + return torch_major_and_minor , ipex_major_and_minor