diff --git a/pytorch_blade/torch_blade/dynamo/__init__.py b/pytorch_blade/torch_blade/dynamo/__init__.py index 4d92ba9372d..cb2a8aca791 100644 --- a/pytorch_blade/torch_blade/dynamo/__init__.py +++ b/pytorch_blade/torch_blade/dynamo/__init__.py @@ -11,10 +11,10 @@ import torch_blade.dynamo.patch_user_defined -from torch._dynamo.optimizations.training import aot_autograd -from torch._dynamo.optimizations.backends import BACKENDS, create_backend -from torch._dynamo.optimizations.subgraph import SubGraph +from torch._dynamo.backends.common import aot_autograd +from torch._dynamo.backends.registry import register_backend from torch._functorch import compilers +from torch._dynamo.utils import torchscript from functorch.compile import min_cut_rematerialization_partition import torch @@ -85,14 +85,13 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) -> return f -@compilers.make_boxed_compiler def disc_compile(fx_g: fx.GraphModule, inps, use_ts=False) -> Callable: return _disc_compile(fx_g, inps, use_ts=False) def disc(fx_g: fx.GraphModule, inps) -> Callable: import tempfile with tempfile.TemporaryDirectory() as tmp: - scripted = SubGraph(fx_g, inps, tmp).scripted + scripted = torchscript(fx_g, inps) torch._C._jit_pass_remove_mutation(scripted.graph) f = torch.jit.freeze(scripted.eval()) cfg = torch_blade.Config() @@ -101,7 +100,6 @@ def disc(fx_g: fx.GraphModule, inps) -> Callable: f = torch_blade.optimize(f, True, tuple(inps)) return f -@compilers.make_boxed_compiler def disc_compile_ts(fx_g: fx.GraphModule, inps, use_ts=False) -> Callable: return _disc_compile(fx_g, inps, use_ts=True) @@ -194,7 +192,6 @@ def _get_disc_decomp(): ] ) return decompositions_dict - aot_disc = aot_autograd( # these are taken from memory_efficient_fusion() fw_compiler=disc_compile, @@ -203,7 +200,6 @@ def _get_disc_decomp(): decompositions=_get_disc_decomp(), partition_fn=min_cut_rematerialization_partition) - aot_disc_debug = aot_autograd( # these are taken from memory_efficient_fusion() fw_compiler=disc_compile_ts, @@ -212,6 +208,6 @@ def _get_disc_decomp(): decompositions=_get_disc_decomp(), partition_fn=min_cut_rematerialization_partition) -BACKENDS["disc"] = disc -BACKENDS["aot_disc"] = aot_disc -BACKENDS["aot_disc_debug"] = aot_disc_debug +register_backend(name="disc", compiler_fn=disc) +register_backend(name="aot_disc", compiler_fn=aot_disc) +register_backend(name="aot_disc_debug", compiler_fn=aot_disc_debug) diff --git a/tao_compiler/mlir/xla/ral/BUILD b/tao_compiler/mlir/xla/ral/BUILD index 36841fe73e6..41657a075fb 100644 --- a/tao_compiler/mlir/xla/ral/BUILD +++ b/tao_compiler/mlir/xla/ral/BUILD @@ -332,7 +332,7 @@ cc_library( "@local_config_cuda//cuda:cuda_driver", "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ - "@org_tensorflow//tensorflow/stream_executor:rocm_platform", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:all_runtime", "@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver", "@local_config_rocm//rocm:rocm_headers", ]) + if_cuda_or_rocm([ @@ -370,7 +370,7 @@ tf_gpu_library( ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_driver", ]) + if_rocm_is_configured([ - "//tensorflow/stream_executor/rocm:rocm_driver", + "@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver", ]), alwayslink = 1, ) @@ -429,7 +429,7 @@ cc_library( "@local_config_cuda//cuda:cuda_driver", "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ - "@org_tensorflow//tensorflow/stream_executor:rocm_platform", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:all_runtime", "@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver", "@local_config_rocm//rocm:rocm_headers", ]), @@ -485,7 +485,7 @@ cc_library( "@local_config_cuda//cuda:cuda_driver", "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ - "@org_tensorflow//tensorflow/stream_executor:rocm_platform", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:all_runtime", "@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver", "@local_config_rocm//rocm:rocm_headers", ]),