Skip to content
Open
18 changes: 7 additions & 11 deletions pytorch_blade/torch_blade/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

import torch_blade.dynamo.patch_user_defined

from torch._dynamo.optimizations.training import aot_autograd
from torch._dynamo.optimizations.backends import BACKENDS, create_backend
from torch._dynamo.optimizations.subgraph import SubGraph
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.registry import register_backend
from torch._functorch import compilers
from torch._dynamo.utils import torchscript
from functorch.compile import min_cut_rematerialization_partition

import torch
Expand Down Expand Up @@ -85,14 +85,13 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) ->

return f

@compilers.make_boxed_compiler
def disc_compile(fx_g: fx.GraphModule, inps, use_ts=False) -> Callable:
return _disc_compile(fx_g, inps, use_ts=False)

def disc(fx_g: fx.GraphModule, inps) -> Callable:
import tempfile
with tempfile.TemporaryDirectory() as tmp:
scripted = SubGraph(fx_g, inps, tmp).scripted
scripted = torchscript(fx_g, inps)
torch._C._jit_pass_remove_mutation(scripted.graph)
f = torch.jit.freeze(scripted.eval())
cfg = torch_blade.Config()
Expand All @@ -101,7 +100,6 @@ def disc(fx_g: fx.GraphModule, inps) -> Callable:
f = torch_blade.optimize(f, True, tuple(inps))
return f

@compilers.make_boxed_compiler
def disc_compile_ts(fx_g: fx.GraphModule, inps, use_ts=False) -> Callable:
return _disc_compile(fx_g, inps, use_ts=True)

Expand Down Expand Up @@ -194,7 +192,6 @@ def _get_disc_decomp():
]
)
return decompositions_dict

aot_disc = aot_autograd(
# these are taken from memory_efficient_fusion()
fw_compiler=disc_compile,
Expand All @@ -203,7 +200,6 @@ def _get_disc_decomp():
decompositions=_get_disc_decomp(),
partition_fn=min_cut_rematerialization_partition)


aot_disc_debug = aot_autograd(
# these are taken from memory_efficient_fusion()
fw_compiler=disc_compile_ts,
Expand All @@ -212,6 +208,6 @@ def _get_disc_decomp():
decompositions=_get_disc_decomp(),
partition_fn=min_cut_rematerialization_partition)

BACKENDS["disc"] = disc
BACKENDS["aot_disc"] = aot_disc
BACKENDS["aot_disc_debug"] = aot_disc_debug
register_backend(name="disc", compiler_fn=disc)
register_backend(name="aot_disc", compiler_fn=aot_disc)
register_backend(name="aot_disc_debug", compiler_fn=aot_disc_debug)
8 changes: 4 additions & 4 deletions tao_compiler/mlir/xla/ral/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ cc_library(
"@local_config_cuda//cuda:cuda_driver",
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured([
"@org_tensorflow//tensorflow/stream_executor:rocm_platform",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:all_runtime",
"@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver",
"@local_config_rocm//rocm:rocm_headers",
]) + if_cuda_or_rocm([
Expand Down Expand Up @@ -370,7 +370,7 @@ tf_gpu_library(
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_driver",
]) + if_rocm_is_configured([
"//tensorflow/stream_executor/rocm:rocm_driver",
"@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver",
]),
alwayslink = 1,
)
Expand Down Expand Up @@ -429,7 +429,7 @@ cc_library(
"@local_config_cuda//cuda:cuda_driver",
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured([
"@org_tensorflow//tensorflow/stream_executor:rocm_platform",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:all_runtime",
"@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver",
"@local_config_rocm//rocm:rocm_headers",
]),
Expand Down Expand Up @@ -485,7 +485,7 @@ cc_library(
"@local_config_cuda//cuda:cuda_driver",
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured([
"@org_tensorflow//tensorflow/stream_executor:rocm_platform",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:all_runtime",
"@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver",
"@local_config_rocm//rocm:rocm_headers",
]),
Expand Down