diff --git a/file.gdb b/file.gdb new file mode 100644 index 00000000000000..5bba091bbdcaa1 --- /dev/null +++ b/file.gdb @@ -0,0 +1,5 @@ +break core/common_runtime/immutable_executor_state.cc:89 +continue +break core/common_runtime/direct_session.cc:918 +break core/common_runtime/direct_session.cc:745 +break core/common_runtime/propagator_state.cc:106 \ No newline at end of file diff --git a/recipe.txt b/recipe.txt new file mode 100644 index 00000000000000..30dd6eff463f7c --- /dev/null +++ b/recipe.txt @@ -0,0 +1,159 @@ +========================================= IN A NEW NODE ================================ + +1. Clone your repo + +2. Install Docker + +3. Install VS code extensions etc + +4. Create Docker image using this Dockerfile + +########### Dockerfile for my own project ################ + +FROM tensorflow/tensorflow:latest + +RUN rm -rf /tensorflow +COPY ./tensorflow /tensorflow + +RUN apt-get update && apt-get install clang -y \ + && apt-get install -y gdb \ + && apt-get install -y git + +RUN echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list > /dev/null + +RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add - +RUN apt update +RUN apt install -y bazel-6.5.0 +RUN apt install -y bazel + +RUN git clone https://github.com/GeorgeVasilakopoulos/tensorflow.git +WORKDIR tensorflow +RUN git checkout recursion + +pip install --upgrade pip setuptools wheel + +bazel build --config=dbg //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu +pip install bazel-bin/tensorflow/tools/pip_package/wheel_house/tensorflow_cpu-2.17.0-cp311-cp311-linux_x86_64.whl + +############################################################# + +5. 
Create a container like this + + +###### +docker run -d --restart always -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -w /tensorflow -v $PWD:/mnt \ + -e HOST_PERMS="$(id -u):$(id -g)" my_image bash +###### + +======================================================================================= + +CREATING KELLY'S IMAGE + +1. Clone her repo + +2. Build docker image using the following Dockerfile: + + +# docker build -t kelly_image . + +############# Dockerfile for Kelly's project ################ + +FROM tensorflow/tensorflow:1.4.0 + +RUN rm -rf /tensorflow +COPY ./tensorflow /tensorflow + + + +RUN apt-get update \ + && apt-get install -y curl wget \ + && apt-get install -y software-properties-common \ + && apt-get install -y unzip \ + && apt-get install -y git \ + && apt-get install -y gcc g++ \ + && apt-get install -y gdb +########################################################## + + +4. RUN + +# apt-get upgrade + + +5. Install Conda and create virtual environment: + +# cd +# wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh +# chmod +x Anaconda3-2022.05-Linux-x86_64.sh +# ./Anaconda3-2022.05-Linux-x86_64.sh + +... + +# conda create -n venv pip python=3.7 +# conda activate venv + +6. Install some stuff: + +pip install -U --user pip six numpy wheel setuptools mock future>=0.17.1 +pip install -U --user keras_applications==1.0.6 --no-deps +pip install -U --user keras_preprocessing==1.0.5 --no-deps + + +7. Install bazel: + +# cd +# wget https://raw.githubusercontent.com/acharal/tensorflow/recursive-functions/tensorflow/tools/ci_build/install/install_bazel.sh + +# chmod +x install_bazel.sh +# ./install_bazel.sh + + +8. 
In tensorflow/workspace.bzl change the installation of cython to + + + +############## +native.new_http_archive( + name = "cython", + sha256 = "94916d1ede67682638d3cc0feb10648ff14dc51fb7a7f147f4fedce78eaaea97", + urls = [ + "https://files.pythonhosted.org/packages/f0/66/6309291b19b498b672817bd237caec787d1b18013ee659f17b1ec5844887/Cython-0.29.tar.gz", + ], + strip_prefix = "Cython-0.29", + build_file = str(Label("//third_party:cython.BUILD")), + ) +############## + + +9. Build Tensorflow as follows: + + +# ./configure +# bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package --cxxopt="-g" --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-fpermissive" +# bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg/ +# pip3 uninstall -y tensorflow +# pip3 install /tmp/tensorflow_pkg/tensorflow-1.4.2-cp37-cp37m-linux_x86_64.whl + + + + + + +Comments: + +// RUN apt-get install -y software-properties-common +// RUN apt-get install unzip +// RUN apt-get update +// RUN add-apt-repository -y ppa:ubuntu-toolchain-r/test +// // RUN ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt +// RUN apt-get install -y gcc-11 g++-11 + +// RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 60 --slave /usr/bin/g++ g++ /usr/bin/g++-11 +// RUN wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh +// RUN chmod +x Anaconda3-2022.05-Linux-x86_64.sh + +// // RUN conda create -n venv pip python=3.7 +// RUN wget https://raw.githubusercontent.com/acharal/tensorflow/recursive-functions/tensorflow/tools/ci_build/install/install_bazel.sh + + + diff --git a/recursion-tests/comp b/recursion-tests/comp new file mode 100755 index 00000000000000..90fa7e03907f60 --- /dev/null +++ b/recursion-tests/comp @@ -0,0 +1,6 @@ +#!/bin/bash + +cd .. 
+bazel build --disk_cache=/root/mycache --per_file_copt=+tensorflow.*,-tensorflow/compiler.*,-tensorflow/lite.*,-tensorflow/core/kernels.*@-O0,-g //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu && +pip uninstall tensorflow_cpu -y && +pip install bazel-bin/tensorflow/tools/pip_package/wheel_house/tensorflow_cpu-2.17.0-cp311-cp311-linux_x86_64.whl diff --git a/recursion-tests/distributed/d2.py b/recursion-tests/distributed/d2.py new file mode 100644 index 00000000000000..0a1114671dc8f5 --- /dev/null +++ b/recursion-tests/distributed/d2.py @@ -0,0 +1,13 @@ +import os + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + +import tensorflow as tf + + +cluster_spec = { + "local": ["localhost:2222", "localhost:2223"] + } + +server = tf.distribute.Server(cluster_spec, job_name="local", task_index=0) +server.join() \ No newline at end of file diff --git a/recursion-tests/distributed/d3.py b/recursion-tests/distributed/d3.py new file mode 100644 index 00000000000000..d2a190329de22a --- /dev/null +++ b/recursion-tests/distributed/d3.py @@ -0,0 +1,14 @@ +import os + + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + +import tensorflow as tf + + +cluster_spec = { + "local": ["localhost:2222", "localhost:2223"] + } + +server = tf.distribute.Server(cluster_spec, job_name="local", task_index=1) +server.join() diff --git a/recursion-tests/distributed/distr.py b/recursion-tests/distributed/distr.py new file mode 100644 index 00000000000000..615aaa61cad69e --- /dev/null +++ b/recursion-tests/distributed/distr.py @@ -0,0 +1,50 @@ +import os + + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + +import tensorflow as tf +from tensorflow.python.framework import function + +tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() + + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + + + + + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) 
+def FacImpl(n): + + def f1(): + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + ret = tf.constant(1) + return ret + def f2(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + ret = n * fac(n - 1) + return ret + + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + pred = tf.less_equal(n, 1) + + return tf.cond(pred, f1, f2) + +FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) + +n = tf.constant(10) +x = fac(n) + +#print(tf.get_default_graph().as_graph_def()) + +# writer = tf.compat.v1.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) + +with tf.compat.v1.Session("grpc://localhost:2222") as sess: + print(sess.run(x)) + +# writer.close() diff --git a/recursion-tests/exponents.py b/recursion-tests/exponents.py new file mode 100644 index 00000000000000..f107979ec154a3 --- /dev/null +++ b/recursion-tests/exponents.py @@ -0,0 +1,52 @@ +import os +import tensorflow as tf +from tensorflow.python.framework import function + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + +tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() + +exp = function.Declare("EXPONENT", [("x", tf.float32), ("n", tf.int32)], [("ret", tf.float32)]) + + + + +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.float32, tf.int32, func_name="EXPONENT", out_names=["ret"]) +def ExpImpl(x, n): + return tf.cond(tf.equal(n,0), + lambda: tf.constant(1.0), + lambda: x*exp(x,n-1)) + + +# @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +# def FacImpl2(n): +# return t(1) + + +ExpImpl.add_to_graph(tf.compat.v1.get_default_graph()) +# t.add_to_graph(tf.compat.v1.get_default_graph()) +# FacImpl2.add_to_graph(tf.compat.v1.get_default_graph()) + + +x = tf.compat.v1.get_variable('n_var', [], initializer=tf.constant_initializer(4.0)) +y = ExpImpl(x,2) + +train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(y) +print(tf.compat.v1.get_default_graph().as_graph_def()) + + 
+sess = tf.compat.v1.Session() +sess.run(tf.compat.v1.initialize_all_variables()) +print(x.eval(session=sess)) +print(sess.run(train_op)) +print(x.eval(session=sess)) + +# writer = tf.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) + +# with tf.compat.v1.Session() as sess: +# result = ExpImpl(2,5) +# print("Result:", sess.run(result)) + diff --git a/recursion-tests/factorial.py b/recursion-tests/factorial.py new file mode 100644 index 00000000000000..f5ca7d668de073 --- /dev/null +++ b/recursion-tests/factorial.py @@ -0,0 +1,15 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: n * FacImpl(n - 1)) + + +print(FacImpl(5)) + + diff --git a/recursion-tests/nohup.out b/recursion-tests/nohup.out new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/recursion-tests/takeuchi.py b/recursion-tests/takeuchi.py new file mode 100644 index 00000000000000..0da79c39f64648 --- /dev/null +++ b/recursion-tests/takeuchi.py @@ -0,0 +1,22 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() + +tak = function.Declare("Tak", [("x", tf.int32), ("y", tf.int32), ("z", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, tf.int32, tf.int32, func_name="Tak", out_names=["ret"]) +def TakImpl(x,y,z): + return tf.cond(tf.less(y, x), + lambda: tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)), + lambda: z) + +TakImpl.add_to_graph(tf.compat.v1.get_default_graph()) + + +with tf.compat.v1.Session() as sess: + result = TakImpl(24,16,8) + print("Result:", sess.run(result)) + +#print(tf.get_default_graph().as_graph_def()) \ No newline at end of file diff --git a/recursion-tests/test.py 
b/recursion-tests/test.py new file mode 100644 index 00000000000000..0cd496e37ce579 --- /dev/null +++ b/recursion-tests/test.py @@ -0,0 +1,35 @@ +# import os + +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + + +import tensorflow as tf +from tensorflow.python.framework import function + + +tf.compat.v1.disable_eager_execution() + +# tf.logging.set_verbosity(tf.logging.INFO) +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + + +@function.Defun(tf.int32, func_name="Test", out_names=["ret"]) +def t(n): + return tf.constant(1) + + + +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return t(n) + + + +FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) + +with tf.compat.v1.Session() as sess: + result = FacImpl(1) + print("Result:", sess.run(result)) + diff --git a/recursion-tests/test2.py b/recursion-tests/test2.py new file mode 100644 index 00000000000000..673c56675ac30f --- /dev/null +++ b/recursion-tests/test2.py @@ -0,0 +1,49 @@ +import os + +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' +# os.environ['TF_DUMP_GRAPH_NAME_FILTER'] = 'Fac' +import tensorflow as tf +from tensorflow.python.framework import function + + + +tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() + +# tf.logging.set_verbosity(tf.logging.INFO) +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + + + + +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: n * fac(n - 1)) + + +# @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +# def FacImpl2(n): +# return t(1) + + +FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) +# 
t.add_to_graph(tf.compat.v1.get_default_graph()) +# FacImpl2.add_to_graph(tf.compat.v1.get_default_graph()) + + +print(tf.compat.v1.get_default_graph().as_graph_def()) + + +writer = tf.compat.v1.summary.FileWriter('/tensorflow/recursion-tests/graph', tf.compat.v1.get_default_graph()) +# writer = tf.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) + +with tf.compat.v1.Session() as sess: + result = FacImpl(10) + print("Result:", sess.run(result)) + diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 9812b0a7dfcef3..14992dda6b44ca 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tstring.h" +#include "tensorflow/core/framework/function.pb.h" // -------------------------------------------------------------------------- // C API for TensorFlow. // @@ -860,6 +861,11 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status); + + +TF_CAPI_EXPORT extern void TF_GraphAddFunctionDef(TF_Graph* g, const void* proto, size_t proto_len, TF_Status* status); + + // Adds a copy of function `func` and optionally its gradient function `grad` // to `g`. Once `func`/`grad` is added to `g`, it can be called by creating // an operation using the function's name. 
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 25805954eff67c..a11cdf28b33964 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -252,6 +252,24 @@ const char* TF_FunctionName(TF_Function* func) { return func->record->fdef().signature().name().c_str(); } + +void TF_GraphAddFunctionDef(TF_Graph* g, const void* proto, size_t proto_len, TF_Status* status){ + + tensorflow::mutex_lock l(g->mu); + tensorflow::FunctionDef fdef; + bool success = fdef.ParseFromArray(proto, proto_len); + if (!success) { + status->status = InvalidArgument( + "Invalid FunctionDef given to TF_GraphAddFunctionDef"); + return; + } + + + tensorflow::StackTracesMap stack_traces; + status->status = g->graph.AddFunctionDef(fdef,std::move(stack_traces)); +} + + void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, const TF_Function* grad, TF_Status* status) { if (func == nullptr) { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index e2adb15245c183..a6749b3aeb915c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -579,6 +579,7 @@ cc_library( "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:filesystem_ops", + "//tensorflow/core/kernels:function_control_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:grappler", @@ -915,6 +916,7 @@ filegroup( "encode_proto_ops_op_lib", "experimental_dataset_ops_op_lib", "filesystem_ops_op_lib", + "function_control_ops_op_lib", "function_ops_op_lib", "functional_grad", "functional_ops_op_lib", @@ -1764,7 +1766,7 @@ alias( tf_cuda_library( name = "graph", srcs = ["//tensorflow/core/graph:graph_srcs"], - hdrs = ["//tensorflow/core/graph:graph_headers"], + hdrs = ["//tensorflow/core/graph:graph_headers", "//tensorflow/core/common_runtime:graph_constructor.h"], deps = [ ":framework", ":framework_internal", diff --git 
a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 50285f87b2283c..c90606be3c421e 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1722,7 +1722,7 @@ Status DirectSession::CreateGraphs( for (auto& partition : partitions) { std::unique_ptr device_graph( - new Graph(client_graph->flib_def.get())); + new Graph(OpRegistry::Global())); device_graph->SetConstructionContext(ConstructionContext::kDirectSession); GraphConstructorOptions device_opts; // There are internal operations (e.g., send/recv) that we now allow. diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 3789b94a20757e..8cab5632b8f29d 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" @@ -1541,17 +1542,24 @@ std::unique_ptr SymbolicGradientHelper::Compute() { g)); // Remove the old return nodes from the function body. 
- for (Node* n : gbody->ret_nodes) { - g->RemoveNode(n); - } - gbody->ret_types = fbody_->arg_types; + // for (Node* n : gbody->ret_nodes) { + // g->RemoveNode(n); + // } + // gbody->ret_types = fbody_->arg_types; + + // Concatenate vectors + gbody->ret_types.insert(gbody->ret_types.end(), fbody_->arg_types.begin(), fbody_->arg_types.end()); + + printf("After adding gradients:\n", SummarizeGraphDef(g->ToGraphDefDebug()).c_str()); + + // TODO(apassos): use the right dtype for gradients of resource variables for (int i = 0; i < gbody->ret_types.size(); ++i) { if (gbody->ret_types[i] == DT_RESOURCE) { gbody->ret_types[i] = DT_FLOAT; } } - gbody->ret_nodes.clear(); + // gbody->ret_nodes.clear(); // Add new return nodes to the function gradient body for each node // in 'x_grad_nodes'. const int arg_types_size = static_cast(fbody_->arg_types.size()); diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 66109aee89eaa9..49fd0f9c032bdd 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -63,6 +63,15 @@ namespace { // can skip expensive duplicates check in 'AddControlEdge'. 
static constexpr const bool kDoNotCheckDuplicates = true; +inline bool IsCall(const NodeDef& node_def){ + return node_def.op() == "Call" || node_def.op() == "RefCall"; +} + +inline bool IsReturn(const NodeDef& node_def){ + return node_def.op() == "Return" || node_def.op() == "RefReturn"; +} + + inline bool IsMerge(const NodeDef& node_def) { return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge"; @@ -201,6 +210,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); TF_RETURN_IF_ERROR(BuildNodeIndex()); + TF_RETURN_IF_ERROR(PopulateFunctionReturningNodes()); TF_RETURN_IF_ERROR(InitFromEdges()); // NOTE: Convert() invokes `consume_node_def()` on each node in the input @@ -228,6 +238,7 @@ class GraphConstructor { Status PopulateReturnTensors(); Status PopulateReturnNodes(); Status PopulateMissingUnusedInputMapKeys(); + Status PopulateFunctionReturningNodes(); FunctionDefLibraryStackTraces CreateStackTracesForFunctionDefLibrary( const FunctionDefLibrary& library) const; @@ -261,6 +272,10 @@ class GraphConstructor { void AddPrefixToNodeDef(const std::vector& input_already_exists, NodeDef* node_def); + bool IsReturningNode(const NodeDef& node_def){ + return (function_returning_nodes_.find(node_def.name()) != function_returning_nodes_.end()); + } + // Modifies `node_def` if its name isn't unique, or if any of its inputs' // names have been uniquified. This must be called in topological order on all // nodes. @@ -286,7 +301,7 @@ class GraphConstructor { // Decrement pending count for users of `processed` and add the ones that now // have all of their pending inputs satisfied to `ready_`. 
- void UpdatePendingCountAndReady(int processed, bool is_next_iteration); + void UpdatePendingCountAndReady(int processed, bool is_next_iteration, bool is_function_call); // Subclasses override the following virtual methods to provide efficient // access to the original protocol buffer-based graph. @@ -405,6 +420,7 @@ class GraphConstructor { int dst_index; }; std::vector back_edges_; + std::unordered_set function_returning_nodes_; GraphConstructor(const GraphConstructor&) = delete; void operator=(const GraphConstructor&) = delete; @@ -560,20 +576,21 @@ Status MaybeAppendVersionWarning(const VersionDef* versions, } void GraphConstructor::UpdatePendingCountAndReady(int processed, - bool is_next_iteration) { + bool is_next_iteration, bool is_function_call) { for (size_t i = 0; i < outputs_[processed].size(); ++i) { const int output = outputs_[processed][i]; // We didn't consider NextIteration->Merge edges when computing // pending_counts_ so we should not have to consider it here either. bool is_next_iteration_to_merge_edge = is_next_iteration && merge_node_indices_.count(output) == 1; - if (!is_next_iteration_to_merge_edge) { - int* current_pending_count = &pending_count_[output]; - CHECK_GT(*current_pending_count, 0); - (*current_pending_count)--; - if (*current_pending_count == 0) { - ready_.insert(output); - } + if (is_next_iteration_to_merge_edge)continue; + int* current_pending_count = &pending_count_[output]; + if (*current_pending_count == 0 && is_function_call) continue; + if (*current_pending_count == 0 && merge_node_indices_.count(output) == 1) continue; + // CHECK_GT(*current_pending_count, 0); + (*current_pending_count)--; + if (*current_pending_count == 0) { + ready_.insert(output); } } } @@ -646,6 +663,44 @@ Status GraphConstructor::EnsureNoNameCollisions() { return absl::OkStatus(); } +Status GraphConstructor::PopulateFunctionReturningNodes() { + std::unordered_map> returning_nodes; + for (int n = 0; n < node_def_count(); ++n) { + const NodeDef& 
node_def = get_node_def(n); + if (IsReturn(node_def)){ + // Nodes that send their output to "Return" nodes are + // function Returning Nodes and in case of recursive functions + // those nodes are part of graph cycles. + for (const auto& input_name : node_def.input()) { + // In order to detect the recursion cycles we depend on + // the fact that a recursive function's returning node, + // will be sending outputs to at least 2 "Return" nodes + // with different "call_id" attributes (same "call_id" + // attrs would mean that they belong in the same function call + // but they correspond to different function outputs) + if (!absl::StartsWith(input_name, "^")) { + string prevNode = input_name; + size_t pos = input_name.find(":"); + if (pos != std::string::npos) + prevNode = input_name.substr(0, pos); + + + int call_id; + TF_CHECK_OK(GetNodeAttr(AttrSlice(node_def), "call_id", &call_id)); + returning_nodes[prevNode].emplace(call_id); + } + } + } + } + for (auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + function_returning_nodes_.insert(retnode.first); + } + } + return OkStatus(); +} + Status GraphConstructor::ValidateInputMapAndControlDependencies() { for (const auto& mapping : opts_.input_map) { TensorId src = mapping.first; @@ -729,15 +784,42 @@ Status GraphConstructor::InitFromEdges() { } } + gtl::FlatSet call_nodes; + gtl::FlatSet merge_return_nodes; + for (int n = 0; n < node_def_count(); ++n) { + const NodeDef& node_def = get_node_def(n); + if (IsCall(node_def)) { + call_nodes.insert(node_def.name()); + } + if (!IsMerge(node_def) && IsReturningNode(node_def)){ + for (const auto& input_name : node_def.input()) { + if (!absl::StartsWith(input_name, "^")) { + string prevNode = input_name; + size_t pos = input_name.find(":"); + + if (pos != std::string::npos) + prevNode = input_name.substr(0, pos); + + merge_return_nodes.insert(prevNode); + } + } + } + } + + + + // Parse the inputs for each node. 
for (int n = 0; n < num_nodes; ++n) { const NodeDef& node_def = get_node_def(n); int pending_count = node_def.input_size(); - if (IsMerge(node_def)) { - // Cycles in the graph are only allowed for while loops. A while loop is - // identified by an edge from a NextIteration node to a Merge node. For - // such Merge nodes, only wait for one non-control input before - // considering the node ready to process in Convert(). + if (IsMerge(node_def) && !IsReturningNode(node_def)) { + // Cycles in the graph are only allowed for while loops and recursion. + // A while loop is identified by an edge from a NextIteration node to a Merge node. + // A recursion is identified by an edge from a Call Node to a Merge node + // In recursion, function returning nodes also participate in a cycle + // For such Merge nodes, and for function returning nodes only wait for + // one non-control input before considering the node ready to process in Convert(). int32_t num_control_edges = 0; bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { @@ -747,15 +829,28 @@ Status GraphConstructor::InitFromEdges() { } else { TensorId id(ParseTensorName(input_name)); if (next_iteration_nodes.find(string(id.first)) != - next_iteration_nodes.end()) { + next_iteration_nodes.end()|| + call_nodes.find(string(id.first)) != + call_nodes.end()|| + merge_return_nodes.find(node_def.name()) != + merge_return_nodes.end()) { has_loop_back_edge = true; } } } if (has_loop_back_edge) { - pending_count = num_control_edges + 1; + pending_count = std::min(num_control_edges + 1, node_def.input_size()); } - } + } else if (IsReturningNode(node_def)) { + int num_control_edges = 0; + for (int i = 0; i < node_def.input_size(); ++i) { + StringPiece input_name(node_def.input(i)); + if (absl::StartsWith(input_name, "^")) { + num_control_edges++; + } + } + pending_count = std::min(num_control_edges + 1, node_def.input_size()); + } for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name 
= node_def.input(i); TensorId id(ParseTensorName(input_name)); @@ -1218,7 +1313,7 @@ Status GraphConstructor::Convert() { TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped)); if (is_node_mapped) { // Skip this node after updating pending_count_ for outputs - UpdatePendingCountAndReady(o, IsNextIteration(node_def)); + UpdatePendingCountAndReady(o, IsNextIteration(node_def), IsCall(node_def)); continue; } } @@ -1277,10 +1372,10 @@ Status GraphConstructor::Convert() { inputs.emplace_back(string(tensor_id.node()), src_node, src_index); } - if (has_data_back_edge && !IsMerge(node_def)) { + if (has_data_back_edge && !IsMerge(node_def) && !IsReturningNode(node_def)) { return errors::InvalidArgument( "Node '", node_def.name(), - "' had a back edge, but only Merge nodes can have back edges."); + "' had a back edge, but only Merge and returning nodes can have back edges."); } Node* node; @@ -1344,7 +1439,7 @@ Status GraphConstructor::Convert() { TF_RETURN_IF_ERROR(ValidateShape(node)); // Update pending_count_ for outputs. - UpdatePendingCountAndReady(o, node->IsNextIteration()); + UpdatePendingCountAndReady(o, node->IsNextIteration(), node->IsCall()); } if (processed < node_def_count()) { diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index c2e115b80c3cb9..19cfd1c4fbaff5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -61,6 +61,9 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #endif // IS_MOBILE_PLATFORM +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + namespace tensorflow { namespace { @@ -849,6 +852,29 @@ Status GraphExecutionState::OptimizeGraph( for (Node* node : optimized_graph->get()->nodes()) { node->set_assigned_device_name(node->requested_device()); } + + /*******************************************************************************************/ + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("Fully_Optimized"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = new_graph.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted("Failed to allocate memory to serialize message of type '" + ,new_graph.GetTypeName(), "' and size ", proto_size); + } + new_graph.SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + /*******************************************************************************************/ + printf("Transformation passed successfully"); + + + return absl::OkStatus(); } else { return errors::InvalidArgument("Meta Optimizer disabled"); diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index 87b3a12891d45a..f118822a9525e5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -58,7 +58,7 @@ struct ClientGraph { DataTypeVector feed_types, DataTypeVector fetch_types, int64_t collective_graph_key) : flib_def(std::move(flib)), - graph(flib_def.get()), + graph(*flib_def), feed_types(std::move(feed_types)), fetch_types(std::move(fetch_types)), collective_graph_key(collective_graph_key) {} diff --git 
a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h index ed9b14cfa1f73d..e3396f7d36ccd3 100644 --- a/tensorflow/core/common_runtime/graph_view.h +++ b/tensorflow/core/common_runtime/graph_view.h @@ -67,10 +67,14 @@ struct NodeItem { bool is_constant_enter : 1; // True iff IsEnter(node) and // node->GetAttr("is_constant") == true. bool is_exit : 1; // True iff IsExit(node) + bool is_call : 1; // True iff IsCall(node) + bool is_return : 1; // True iff IsReturn(node) bool is_control_trigger : 1; // True iff IsControlTrigger(node) bool is_source : 1; // True iff IsSource(node) // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) bool is_enter_exit_or_next_iter : 1; + // True iff IsCall(node) || IsReturn(node) + bool is_call_or_return : 1; bool is_transfer_node : 1; // True iff IsTransferNode(node) bool is_initialization_op : 1; // True iff IsInitializationOp(node) bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) @@ -107,6 +111,12 @@ struct NodeItem { // Number of output control edges. int32 num_output_control_edges; + string frame_name; // cache the attribute if is_enter | is-exit | is_call | is_return + string dyn_frame_name; // cache the attribute if is_enter | is-exit | is_call | is_return + + int call_id = -1; + + // If non-null, contains an array of num_outputs bools, where the ith bool // is true if and only if the ith output is consumed by another node. 
std::unique_ptr outputs_required; diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index e3a2435505e041..2bf3c0ba1ebe41 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -96,6 +96,8 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { pending_ids_.resize(gview_.num_nodes()); + std::unordered_map input_count; + // Preprocess every node in the graph to create an instance of op // kernel for each node. requires_control_flow_ = false; @@ -103,6 +105,8 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { if (IsSink(n)) continue; if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) { requires_control_flow_ = true; + } else if(IsCall(n) || IsReturn(n)){ + requires_control_flow_ = true; } else if (IsRecv(n)) { // A Recv node from a different device may produce dead tensors from // non-local control-flow nodes. 
@@ -186,6 +190,23 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { } else { item->is_constant_enter = false; } + + item->is_call = IsCall(n); + + if(item->is_call){ + string frame_name; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name)); + FrameInfo* frame_info = frame_info_[frame_name].get(); + frame_info->parallel_iterations = 1; + if (call_frame_info_.size() <= id) { + call_frame_info_.resize(id + 1); + } + call_frame_info_[id] = frame_info; + } + + + item->is_return = IsReturn(n); + item->is_call_or_return = (IsCall(n) || IsReturn(n)); item->is_exit = IsExit(n); item->is_control_trigger = IsControlTrigger(n); item->is_source = IsSource(n); @@ -217,6 +238,19 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { string enter_name; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); EnsureFrameInfo(enter_name)->input_count++; + item->frame_name = enter_name; + item->dyn_frame_name = enter_name; + } + if(item->is_call_or_return){ + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &item->frame_name)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "call_id", &item->call_id)); + item->dyn_frame_name = strings::StrCat(item->call_id); + } + if (item->is_call) { + input_count[item->dyn_frame_name]++; + // The following assumes that all the calls of same function have the same number of inputs + // which is of course apparent for a well-formed graph (produced by the transformation) + EnsureFrameInfo(item->frame_name)->input_count = input_count[item->dyn_frame_name]; } // Record information about whether each output of the op is used. 
@@ -303,6 +337,7 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, } } + std::unordered_map call_id_to_call_node_id; while (!ready.empty()) { Node* curr_node = ready.front(); int curr_id = curr_node->id(); @@ -323,6 +358,31 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, } frame_name = cf_info->frame_names[parent->id()]; parent = parent_nodes[parent->id()]; + } else if (IsCall(curr_node)) { + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name)); + + int call_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(),"call_id", &call_id)); + // we assume that call_id is unique and we don't need to concat with frame_name + // to make it unique. + call_id_to_call_node_id.emplace(call_id, curr_id); + parent = curr_node; + } else if (IsReturn(curr_node)) { + int call_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "call_id", &call_id)); + + auto it = call_id_to_call_node_id.find(call_id); + if (it != call_id_to_call_node_id.end()) { + int call_node_id = it->second; + parent = parent_nodes[call_node_id]; + frame_name = cf_info->frame_names[call_node_id]; + } else { + ready.push_back(curr_node); + continue; + } } else { parent = parent_nodes[curr_id]; frame_name = cf_info->frame_names[curr_id]; @@ -331,6 +391,7 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, for (const Edge* out_edge : curr_node->out_edges()) { Node* out = out_edge->dst(); if (IsSink(out)) continue; + if (IsReturn(out) && out_edge->IsControlEdge()) continue; const int out_id = out->id(); // Add to ready queue if not visited. 
diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h index a1fca080ca6c5c..320da16860c5d5 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.h +++ b/tensorflow/core/common_runtime/immutable_executor_state.h @@ -99,6 +99,13 @@ class ImmutableExecutorState { return *enter_frame_info_[node_item.node_id]; } + const FrameInfo& get_call_frame_info(const NodeItem& node_item) const { + DCHECK(node_item.is_call); + return *call_frame_info_[node_item.node_id]; + } + + + bool requires_control_flow_support() const { return requires_control_flow_; } // Copies the pending counts for nodes in this graph to the given array. @@ -147,6 +154,7 @@ class ImmutableExecutorState { // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps // dense node IDs to the corresponding FrameInfo. std::vector enter_frame_info_; + std::vector call_frame_info_; // If `requires_control_flow_` is false, this points to an array of initial // pending counts for the nodes in the graph, indexed by node ID. diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index 9a365177770d4a..67a195ee72cc58 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -7,7 +7,7 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
@@ -42,13 +42,13 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state, 0, new PropagatorState::IterationState(0, root_frame_->pending_counts, root_frame_->total_input_tensors)); - outstanding_frames_.emplace(root_frame_->frame_id, root_frame_); + // outstanding_frames_.emplace(root_frame_->frame_id, root_frame_); } PropagatorState::~PropagatorState() { - for (auto name_frame : outstanding_frames_) { - delete name_frame.second; - } + // for (auto name_frame : outstanding_frames_) { + // delete name_frame.second; + // } } void PropagatorState::ActivateRoots(gtl::ArraySlice roots, @@ -89,7 +89,21 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, FrameState* output_frame = input_frame; IterationState* output_iter = input_iter; - if (!item->is_enter_exit_or_next_iter) { + // if (!item->is_enter_exit_or_next_iter) { + // if (vlog_) { + // VLOG(2) << "Propagate Outputs: " << node->name(); + // VLOG(2) << "Frame: " << input_frame->frame_name; + // } + // printf("Propagate Outputs: %s, am i alive? %d\n", node->name().c_str(), !is_dead); + // printf("Frame: %s\n", input_frame->frame_name.c_str()); + + string output(tagged_node.node_item->kernel->name_view()); + + // printf("Propagate Outputs: %s, am i alive? %d\n",output.c_str(), !is_dead); + // printf("Frame: %s\n", input_frame->frame_name.c_str()); + + + if (!item->is_enter_exit_or_next_iter && !item->is_call_or_return) { // Fast path for node types that don't need special handling. // This is the case for most nodes. DCHECK_EQ(input_frame, output_frame); @@ -131,6 +145,37 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, /*decrement_activation=*/0); is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); } + } else if (item->is_call) { + // if (is_dead) { + // // Stop the deadness propagation. 
+ // output_frame = nullptr; + // } else { + FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); + // printf("Inside Call: %s. Input frame id: %d, Output frame id %d\n", output.c_str(),input_frame->frame_id,output_frame->frame_id); + output_iter = output_frame->GetIteration(0); + { + mutex_lock l(output_frame->mu); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); + output_frame->num_pending_inputs--; + } + is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); + } else if (item->is_return) { + // if (is_dead) { + // // Stop the deadness propagation. + // output_frame = nullptr; + // } else { + output_frame = input_frame->parent_frame; + output_iter = input_frame->parent_iter; + // printf("Inside Return: %s. Input frame id: %d, Output frame id %d\n", output.c_str(),input_frame->frame_id,output_frame->frame_id); + { + mutex_lock l(output_frame->mu); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); + } + is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); } else { DCHECK(item->is_next_iteration); if (is_dead) { @@ -244,11 +289,11 @@ void PropagatorState::DumpIterationState(const FrameState* frame, void PropagatorState::DumpState() { mutex_lock l(mu_); LOG(WARNING) << "Dumping state"; - for (auto& frame : outstanding_frames_) { - LOG(WARNING) << frame.first; - FrameState* frame_state = frame.second; - frame_state->DumpIterationState(this); - } + // for (auto& frame : outstanding_frames_) { + // LOG(WARNING) << frame.first; + // FrameState* frame_state = frame.second; + // frame_state->DumpIterationState(this); + // } } void PropagatorState::FindOrCreateChildFrame(FrameState* frame, @@ -257,16 +302,16 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* 
frame, FrameState** child) { // Get the child frame name. const ImmutableExecutorState::FrameInfo& frame_info = - immutable_state_.get_enter_frame_info(node_item); + node_item.is_enter ? immutable_state_.get_enter_frame_info(node_item) : immutable_state_.get_call_frame_info(node_item); const uint64 child_id = Hash64Combine( frame->frame_id, - Hash64Combine(iter_state->iter_num, Hash64(frame_info.name))); + Hash64Combine(iter_state->iter_num, Hash64(frame_info.name + ":" + std::to_string(node_item.call_id)))); { - tf_shared_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_id); - if (it != outstanding_frames_.end()) { + tf_shared_lock executor_lock(frame->mu); + auto it = frame->outstanding_child_frames_.find(child_id); + if (it != frame->outstanding_child_frames_.end()) { *child = it->second; return; } @@ -285,6 +330,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, temp->frame_id = child_id; temp->parent_frame = frame; temp->parent_iter = iter_state; + temp->call_id = node_item.call_id; temp->InitializeFrameInfo(frame_info); // Initialize iteration 0. @@ -295,14 +341,13 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, } { - mutex_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_id); - if (it != outstanding_frames_.end()) { + mutex_lock executor_lock(frame->mu); + auto it = frame->outstanding_child_frames_.find(child_id); + if (it != frame->outstanding_child_frames_.end()) { *child = it->second; } else { - mutex_lock frame_lock(frame->mu); iter_state->outstanding_frame_count++; - outstanding_frames_[child_id] = temp; + frame->outstanding_child_frames_[child_id] = temp; *child = temp; temp = nullptr; } @@ -382,8 +427,10 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { // Delete the frame. 
if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id; { - mutex_lock executor_lock(mu_); - outstanding_frames_.erase(frame->frame_id); + if (parent_frame != nullptr) { + mutex_lock parent_frame_lock(parent_frame->mu); + parent_frame->outstanding_child_frames_.erase(frame->frame_id); + } } delete frame; } @@ -551,6 +598,12 @@ int PropagatorState::FrameState::ActivateNodesSlowPathInternal( dst_ready = (adjust_result.pending_count == 1) && dst_dead; } } else { + if (dst_item->is_return) { + // In case of "Return" dst_node, + // we compare node's frame attr with current frame name + // if they are different, ignore this op + if (dst_item->call_id != call_id) continue; + } // Handle all other (non-merge) nodes. // We need to set the input of the op before adjusting activation. @@ -572,6 +625,23 @@ int PropagatorState::FrameState::ActivateNodesSlowPathInternal( increment_dead); dst_dead = adjust_result.dead_count > 0; dst_ready = !(adjust_result.pending_count > 0); + + if (dst_item->is_return && increment_dead) { + // The only dead input a Return op will ever may get + // is the control input propagated to it from a corresponding + // dead Call op in case of untaken branch. So at this point + // we are certain that Return op will never receive another input. + // Therefore, we force it to be added in queue for the sake of + // deadness propagation and we adjust it for activation once more, + // so that it no longer waits for another (never coming) input. + const PendingCounts::AdjustResult adjust_result = + atomic ? 
iter_state->adjust_for_activation_atomic(dst_pending_id, + increment_dead) + : iter_state->adjust_for_activation(dst_pending_id, + increment_dead); + dst_dead = adjust_result.dead_count > 0; + dst_ready = !(adjust_result.pending_count > 0); + } } maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead); diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 680cb13ef3ecb4..0f9099973ba641 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -261,6 +261,8 @@ class PropagatorState { // frame_name. uint64 frame_id; + int call_id = -1; + // The iteration state of its parent frame when this frame is created. // nullptr if there is no parent frame. The frame_name/parent_iter pair // uniquely identifies this FrameState. @@ -281,6 +283,14 @@ class PropagatorState { // The number of outstanding iterations. int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; + // Mapping from frame ID to outstanding frames. A new frame is created + // at some iteration of an active frame. So the unique key for the new + // child frame is a hash composed of the ID of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + absl::flat_hash_map outstanding_child_frames_ + TF_GUARDED_BY(mu); + private: // The active iteration states of this frame. gtl::InlinedVector iterations; @@ -538,14 +548,6 @@ class PropagatorState { // The root frame in which the execution of this step is started. FrameState* root_frame_; - // Mapping from frame ID to outstanding frames. A new frame is created - // at some iteration of an active frame. So the unique key for the new - // child frame is a hash composed of the ID of the parent frame, the iteration - // number at which the parent frame is creating the new frame, and the - // name of the new frame from nodedef. 
- absl::flat_hash_map outstanding_frames_ - TF_GUARDED_BY(mu_); - PropagatorState(const PropagatorState&) = delete; void operator=(const PropagatorState&) = delete; }; diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 5593963988d9e5..607d80378a06b1 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -24,6 +24,9 @@ limitations under the License. #include #include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + #include "absl/status/status.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/profile_handler.h" @@ -356,6 +359,37 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( popts.flib_def = client_graph->flib_def.get(); Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs); if (s.ok()) { + + printf("\n\n MASTER PARTITIONS:\n"); + int i=0; + for (const auto& it: graph_defs) { + string dvc = it.first; + const GraphDef* graphDef = &it.second; + printf("\n\nDeviceName :'%s'\n", dvc.c_str()); + printf("Partition GraphDef:\n %s\n", SummarizeGraphDef(*graphDef).c_str()); + + string p = strings::StrCat("Partition", i); i++; + EventsWriter writer(p); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = graphDef->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + graphDef->GetTypeName(), "' and size ", proto_size); + } + graphDef->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + + } + + + // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain // valid after the call to DoRegisterPartitions begins, so // `stats_publisher_` must make a 
copy if it wants to retain the diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 10984ae23608bc..6c84e6f2b291ce 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -61,6 +61,8 @@ Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) { REF_CLASS("Enter", NC_ENTER), REF_CLASS("Exit", NC_EXIT), REF_CLASS("NextIteration", NC_NEXT_ITERATION), + REF_CLASS("Call", NC_CALL), + REF_CLASS("Return", NC_RETURN), {"LoopCond", NC_LOOP_COND}, {"ControlTrigger", NC_CONTROL_TRIGGER}, {"_Send", NC_SEND}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index c7a4f696bf126d..60927b1543fcab 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -167,6 +167,8 @@ class Node { bool IsEnter() const { return class_ == NC_ENTER; } bool IsExit() const { return class_ == NC_EXIT; } bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; } + bool IsCall() const { return class_ == NC_CALL; } + bool IsReturn() const { return class_ == NC_RETURN; } bool IsLoopCond() const { return class_ == NC_LOOP_COND; } bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; } @@ -182,6 +184,7 @@ class Node { bool IsControlFlow() const { return (class_ != NC_OTHER) && // Fast path (IsSwitch() || IsMerge() || IsEnter() || IsExit() || + IsCall() || IsReturn() || IsNextIteration()); } bool IsHostSend() const { return class_ == NC_HOST_SEND; } @@ -313,6 +316,8 @@ class Node { NC_ENTER, NC_EXIT, NC_NEXT_ITERATION, + NC_CALL, + NC_RETURN, NC_LOOP_COND, NC_CONTROL_TRIGGER, NC_SEND, @@ -935,6 +940,8 @@ inline bool IsMerge(const Node* node) { return node->IsMerge(); } inline bool IsEnter(const Node* node) { return node->IsEnter(); } inline bool IsExit(const Node* node) { return node->IsExit(); } inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); } +inline bool IsCall(const 
Node* node) { return node->IsCall(); } +inline bool IsReturn(const Node* node) { return node->IsReturn(); } inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } inline bool IsSend(const Node* node) { return node->IsSend(); } diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index eac0fa367e5577..1e7a5d0690de8e 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/graph/graph_partition.h" - #include #include #include @@ -43,6 +41,12 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/graph/graph_partition.h" +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" +#include "tensorflow/core/framework/graph_def_util.h" + namespace tensorflow { namespace { @@ -968,6 +972,656 @@ void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { } } +/**************************************************************************************************/ + +struct StateMachineNodeInput { + string src; + int index; +}; + +struct StateMachineParent { + Node* parent_node; + int parent_index; +}; + +struct StateMachineNode { + Node* node; + std::vector inputs; +}; + +struct StateMachineGraph { + std::unordered_map nodes; + std::set depends_on; + Node* merge; +}; + +struct StateMachine { + // A map from unique_ids to StateMachineGraphs representing a general dynamic + // state machine that we update every time a function gets called, and helps us + // gradually build 
the state machines of the partitions + std::unordered_map state_machine_graphs; + // state_machine_parents is the 'spine' of the graph, + // containing only control flow nodes + std::vector state_machine_parents; + + std::unordered_map switches_info; + // + std::unordered_map switchToPred; + + string leader_partition; + + // Maps device names to smaller strings + std::unordered_map device_names_map; + + std::unordered_map*> partitionsToSMG; +}; + +struct FuncInfo { + // A map from function signature name to the num of function's arguments + std::unordered_map funcInputs; + // Helps us separate functions with same frame_name but + // different non recursive call sites + std::unordered_map funcVisitedCounter; + // Each vector below operates as a barrier, + // we don't call CallingFunction(..) before we gather + // all function's arguments/calls first + std::unordered_map*> funcCalls; +}; + +// Adds root nodes into ready_nodes queue and sets ready_inputs appropriately +Status PreprocessGraph(std::unordered_map &ready_inputs, Graph* g, + std::deque &ready_nodes) { + + + std::unordered_map> returning_nodes; + + for (Node* node : g->nodes()) { + + if (node->in_edges().empty()) { + ready_nodes.push_back(node); + } + bool recursion_merge = 0; + if (IsMerge(node)) { + ready_inputs[node] = 0; + for (const Edge* in_edge : node->in_edges()) { + + Node* in = in_edge->src(); + // if (IsNextIteration(*output_map.GetNode(input))) { + // ready_inputs[node]++; + // } + if (IsCall(in)) { + ready_inputs[node]++; + recursion_merge = 1; + } + } + if (recursion_merge) { + ready_inputs[node]--; + recursion_merge = 0; + } + + } else if (IsReturn(node)) { + + for (const Edge* in_edge : node->in_edges()) { + Node* in = in_edge->src(); + + if (!in_edge->IsControlEdge()) { + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "call_id", &call_id)); + returning_nodes[in].emplace(call_id); + } + } + ready_inputs[node] = 0; + + } else { + ready_inputs[node] = 0; + } + } + + for (const auto& retnode : 
returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + ready_inputs[retnode.first]++; + } + } + + return OkStatus(); +} + +string GetDeviceMappedName(StateMachine &state_machine, string device_name) { + + std::unordered_map& device_map = state_machine.device_names_map; + + auto slot = &device_map[device_name]; + if (*slot == "") + *slot = strings::StrCat("_p", device_map.size() + 1); + return *slot; +} + +bool IsCallSuccessor(Node* node) { + + for (const Edge* in_edge : node->in_edges()) { + Node* src = in_edge->src(); + if (IsCall(src) && !in_edge->IsControlEdge()) + return true; + } + return false; +} + +void DeleteStateMachineGraph(StateMachine& state_machine, string unique_id) { + + StateMachineGraph *smg = state_machine.state_machine_graphs[unique_id]; + + for (auto& it : smg->nodes) + delete it.second; + delete smg; +} + +std::vector* GetOrCreateCalls(int call_id, std::unordered_map*> &funcCalls) { + auto slot = &funcCalls[call_id]; + if (*slot == nullptr) + *slot = new std::vector; + return *slot; +} + +std::set* GetOrCreatePartition(string partition, std::unordered_map*> &partsTpSmg) { + auto slot = &partsTpSmg[partition]; + if (*slot == nullptr) + *slot = new std::set; + return *slot; +} + +// For one if-else construction there are more than one Switch nodes guarding all the inputs +// that are needed inside the branches but live outside of them. 
We need to collect all the Switch +// nodes that correspond to one if-else construction and treat them as one in the state machines +// switches_info: Every switch node maps to the original switch that we "ll take into account +void CollectSwitches(Graph* g, StateMachine& state_machine) { + + std::unordered_map pred_switch; + + for (Node *node : g->nodes()) { + + if (IsSwitch(node)) { + + for (const Edge *in_edge : node->in_edges()) { + + int port = in_edge->dst_input(); + + // A sloppy way to determine if this is the predicate input + if (!in_edge->IsControlEdge() && port == 1) { + + Node *predicate = in_edge->src(); + + while (IsIdentity(predicate)) { + for (const Edge *inEdge : predicate->in_edges()) { + if (!inEdge->IsControlEdge()) { + predicate = inEdge->src(); + break; + } + } + } + + // We 've got the real predicate + Node *switchNode; + if (pred_switch.find(predicate) == pred_switch.end()) { + // Original switch + pred_switch[predicate] = node; + state_machine.switchToPred[node] = predicate; + switchNode = node; + } else { + // "Synonym" switch + switchNode = pred_switch[predicate]; + } + + state_machine.switches_info[node] = switchNode; + + break; + } + } + printf("Switch : %s -> %s\n", node->name().c_str(), state_machine.switches_info[node]->name().c_str()); + } + } + + printf("\n\n\n"); +} + +void GatherPartitionStateMachines(StateMachine& state_machine, std::set* smgs) { + + std::deque queue; + + for (auto& it : *smgs) + queue.push_back(it); + + while (!queue.empty()) { + string smg = queue.front(); + queue.pop_front(); + + StateMachineGraph* sm_graph = state_machine.state_machine_graphs[smg]; + for (auto& it : sm_graph->depends_on) { + // If not already visited + if (smgs->find(it) == smgs->end()) { + smgs->emplace(it); + queue.push_back(it); + } + } + } +} + +NodeDef* FindNodeInGraphDef(GraphDef& graphDef, string node_name) { + + for (NodeDef& nodeDef : *graphDef.mutable_node()) { + if (nodeDef.name() == node_name) + return &nodeDef; + } + return 
nullptr; +} + +void ConnectMergeToNode(GraphDef& graphDef, string merge_name, string node_name, + StateMachine& state_machine, string partition_name) { + + // We can safely infer the correct Merge's name and add it as control input to the node + // even though partition state machine's Merge has not already been added into graphdef + string suffix; + (partition_name != state_machine.leader_partition) ? + (suffix = GetDeviceMappedName(state_machine, partition_name)) : (suffix = ""); + + //Add as control input + NodeDef* node = FindNodeInGraphDef(graphDef, node_name); + *node->add_input() = strings::StrCat("^", merge_name, suffix); +} + +void AddPartitionStateMachine(StateMachine& state_machine, GraphDef& main_graphDef, + string unique_id, string partition) { + + StateMachineGraph *sm_graph = state_machine.state_machine_graphs[unique_id]; + string suffix = GetDeviceMappedName(state_machine, partition); + for (const auto &it : sm_graph->nodes) { + string node_name = it.first; + StateMachineNode *sm_node = it.second; + Node *node = sm_node->node; + + // Build NodeDef + NodeDef *nodedef = main_graphDef.add_node(); + //Note: suffix does not guarantee that name is unique + nodedef->set_name(strings::StrCat(node_name, suffix)); + nodedef->set_op(node->op_def().name()); + nodedef->set_device(partition); + + // Add Inputs + for (int i = 0; i < sm_node->inputs.size(); ++i) { + // There won't exist any control inputs here + nodedef->add_input(strings::StrCat(sm_node->inputs[i].src, suffix, ":", sm_node->inputs[i].index)); + + if (absl::StartsWith(StringPiece(sm_node->inputs[i].src),"Dummy_")) { + Tensor tensor(DT_INT32, TensorShape({0})); + NodeDef* dummy = main_graphDef.add_node(); + dummy->set_name(strings::StrCat(sm_node->inputs[i].src, suffix)); + dummy->set_op("Const"); + dummy->set_device(partition); + AddNodeAttr("dtype", DT_INT32, dummy); + AddNodeAttr("value", tensor, dummy); + } + } + + if (IsSwitch(node)) { + // Add predicate input too + 
nodedef->add_input(state_machine.switchToPred[node]->name()); + // Add control input from partition's Merge to partition's Switch + nodedef->add_input(strings::StrCat("^", sm_graph->merge->name(), suffix)); + } + + for (const auto &itt : node->def().attr()) { + // Not sure if this is copying attrs correctly + if (itt.first == "T") { + // We don't care about keeping the original "T" attr + // in state machine nodes + AddNodeAttr(itt.first, DT_INT32, nodedef); + } else + AddNodeAttr(itt.first, itt.second, nodedef); + } + } +} + +Status AddNodeToStateMachine(StateMachine& state_machine, string unique_id, Node* node, bool cycle) { + + StateMachineGraph *smg = state_machine.state_machine_graphs[unique_id]; + StateMachineNode *smn = new StateMachineNode; + + smn->node = node; + + StateMachineParent *parent = &state_machine.state_machine_parents[node->id()]; + + if (parent->parent_node == nullptr) { + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "call_id", &call_id)); + smn->inputs.push_back({strings::StrCat("Dummy_", call_id), 0}); + } else + smn->inputs.push_back({parent->parent_node->name(), parent->parent_index}); + + smg->nodes[node->name()] = smn; + + // If cycle is true, node is a recursive call, that needs to be added as + // input to the corresponding Merge node + if (cycle) { + // We traverse graph the way topological sort does, so we will never + // meet a recursive call node before its corresponding Merge + StateMachineNode* merge = smg->nodes[smg->merge->name()]; + merge->inputs.push_back({node->name(), 0}); + } + + return OkStatus(); +} + +Status CallingFunction(Graph* graph, GraphDef& main_graphDef, StateMachine& state_machine, FuncInfo& funcInfo, + string function_frame_name, int function_call_id, + std::unordered_map& ready_inputs, + std::deque& prev_ready_nodes) { + + Node *merge, *call; + std::deque ready_nodes; + + string function_unique_id = strings::StrCat(function_frame_name, ":", + funcInfo.funcVisitedCounter[function_frame_name]); 
+ + std::vector* calls = funcInfo.funcCalls[function_call_id]; + for (int i=0; i < calls->size(); ++i) { + ready_nodes.push_back((*calls)[i]); + } + call = (*calls)[0]; + + // We add only one Call node for all possible function's args in the state machine + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, function_unique_id, call, false)); + + std::vector& state_machine_parents = state_machine.state_machine_parents; + StateMachineGraph* sm_graph = state_machine.state_machine_graphs[function_unique_id]; + + // Call's successor (the non control output) will be either + // a Merge node (in case of recursion) or an Identity node. + // Either way we add that successor to the state machine, too. + // Same as above, we add only one Merge node instead of one per function's arg + for (const Edge* out_edge : call->out_edges()) { + if (!out_edge->IsControlEdge()) { + merge = out_edge->dst(); + state_machine_parents[merge->id()].parent_node = call; + state_machine_parents[merge->id()].parent_index = 0; + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, function_unique_id, merge, false)); + sm_graph->merge = merge; + break; + } + } + + while (!ready_nodes.empty()) { + + Node* ready_node = ready_nodes.front(); + ready_nodes.pop_front(); + + int parent_index = 0; + Node* parent = state_machine_parents[ready_node->id()].parent_node; + + // The ops below need to update the parent + if (IsCall(ready_node)) { + parent = call; + } else if (IsCallSuccessor(ready_node)) { + parent = merge; + } else if (IsSwitch(ready_node)) { + Node *sw = state_machine.switches_info[ready_node]; + if (sw == ready_node) + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, function_unique_id, ready_node, false)); + parent = sw; + } else if (IsMerge(ready_node)) { + // Control Flow (regular) Merge has a corresponding Switch node + // Parent gets the value of that switch node's parent + parent = state_machine_parents[parent->id()].parent_node; + parent_index = 
state_machine_parents[parent->id()].parent_index; + } else if (IsReturn(ready_node)) { + // Return needs to propagate its corresponding Call's parent to all its successors + for (const Edge* in_edge : ready_node->in_edges()) { + if (in_edge->IsControlEdge()) { + Node* call_node = in_edge->src(); + parent = state_machine_parents[call_node->id()].parent_node; + parent_index = state_machine_parents[call_node->id()].parent_index; + break; + } + } + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(ready_node->attrs(), "call_id", &call_id)); + // If not a 'recursive' return + if (call_id == function_call_id) { + // Add the successors of Return node to prev_ready_nodes queue + prev_ready_nodes.push_back(ready_node); + // Set the parent value of the only actual output of return + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + state_machine_parents[out->id()].parent_node = parent; + state_machine_parents[out->id()].parent_index = parent_index; + break; + } + continue; + } + } + + // Process ready_node's outputs + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + + ready_inputs[out]++; + + // For a cross-device edge, on the dst device, add a control edge + // from the merge node of the state machine to dst. If a send/recv is + // introduced for this edge in future partitioning, we delete this + // control edge and add a new control edge from the merge to the recv. 
+ const string& src_device = ready_node->assigned_device_name(); + const string& dst_device = out->assigned_device_name(); + if (src_device != dst_device) { + if (IsCallSuccessor(ready_node) && IsConstant(out)) { + // Remove this control edge that ensures constant executes in the same frame, + // and add a new one from the Constant's partition's state machine merge to the constant + NodeDef* con_node = FindNodeInGraphDef(main_graphDef, out->name()); + for (string& input : *con_node->mutable_input()) { + if (absl::StartsWith(StringPiece(input),strings::StrCat("^", ready_node->name()))) { + string suffix = GetDeviceMappedName(state_machine, dst_device); + input = strings::StrCat("^", merge->name(), suffix); + break; + } + } + } else + if(merge->name() != out->name()) ConnectMergeToNode(main_graphDef, merge->name(), out->name(), state_machine, dst_device); + } + + if (ready_inputs[out] == out->in_edges().size()) { + + if (IsSwitch(ready_node)) { + // We need to fix parent_index appropriately + parent_index = out_edge->src_output(); + } + + // Set node's parent + state_machine_parents[out->id()].parent_node = parent; + state_machine_parents[out->id()].parent_index = parent_index; + + std::unordered_map& sm_graphs = state_machine.state_machine_graphs; + + if (IsCall(out)) { + + string frame_name; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "frame_name", &frame_name)); + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "call_id", &call_id)); + + std::vector* calls = GetOrCreateCalls(call_id, funcInfo.funcCalls); + calls->push_back(out); + + if (funcInfo.funcInputs[frame_name] == calls->size()) { + + // We gathered all function's inputs + + string unique_id = strings::StrCat(frame_name, ":", funcInfo.funcVisitedCounter[frame_name]); + + if (sm_graphs.find(unique_id) == sm_graphs.end()) { + + sm_graphs.emplace(unique_id, new StateMachineGraph); + TF_RETURN_IF_ERROR(CallingFunction(graph, main_graphDef, state_machine, funcInfo, frame_name, call_id, 
ready_inputs, ready_nodes)); + funcInfo.funcVisitedCounter[frame_name]++; + } else { + // Recursive Call (either to the same function or another one (mutual recursion) + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, unique_id, (*calls)[0], true)); + // Add the recursive call nodes to ready_nodes + for (int i=0; i < calls->size(); ++i) + ready_nodes.push_back((*calls)[i]); + } + + sm_graphs[unique_id]->depends_on.emplace(function_unique_id); + } + } else { + GetOrCreatePartition(dst_device, state_machine.partitionsToSMG)->emplace(function_unique_id); + ready_nodes.push_back(out); + } + } + } + } + + return OkStatus(); +} + +Status AddFunctionStateMachines(const PartitionOptions& opts, + Graph* g, GraphDef& main_graphDef, GraphInfo* g_info) { + + Status status; + GraphDefBuilder::Options bopts(g, &status); + + FuncInfo funcInfo; + int nodes_num = g->num_node_ids(); + + + const FunctionLibraryDefinition* flib_def = opts.flib_def; + if(flib_def == nullptr){ + flib_def = &(g->flib_def()); + } + + const FunctionDefLibrary& fdef = flib_def->ToProto(); + + + + for (const FunctionDef& func : fdef.function()) { + + int num_inputs = func.signature().input_arg_size(); + string name = func.signature().name(); + funcInfo.funcInputs[name] = num_inputs; + funcInfo.funcVisitedCounter[name] = 0; + } + + StateMachine state_machine; + state_machine.state_machine_parents.resize(nodes_num); + + CollectSwitches(g, state_machine); + + // Add all state machines for cross-device frames. + // A state machine is added only when there is a cross-device edge in a + // non-root frame. 
+ + // Visit nodes the way topological sort does + std::deque ready_nodes; + std::unordered_map ready_inputs; + + TF_RETURN_IF_ERROR(PreprocessGraph(ready_inputs, g, ready_nodes)); + + // We convert graph to its equivalent graph_def, cause it's easier + // to extend it with the GraphDef state machines of partitions + g->ToGraphDef(&main_graphDef); + + while (!ready_nodes.empty()) { + Node* ready_node = ready_nodes.front(); + ready_nodes.pop_front(); + + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + + ready_inputs[out]++; + + if (ready_inputs[out] == out->in_edges().size()) { + + if (IsCall(out)) { + string frame_name; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "frame_name", &frame_name)); + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "call_id", &call_id)); + + std::vector* calls = GetOrCreateCalls(call_id, funcInfo.funcCalls); + calls->push_back(out); + + if (funcInfo.funcInputs[frame_name] == calls->size()) { + + string unique_id = strings::StrCat(frame_name, ":", funcInfo.funcVisitedCounter[frame_name]); + + // We gathered all function's inputs + state_machine.leader_partition = out->assigned_device_name(); + state_machine.state_machine_graphs.emplace(unique_id, new StateMachineGraph); + TF_RETURN_IF_ERROR(CallingFunction(g, main_graphDef, state_machine, funcInfo, frame_name, call_id, ready_inputs, ready_nodes)); + funcInfo.funcVisitedCounter[frame_name]++; + + // Adding partition state machines to graph + for (auto& it: state_machine.partitionsToSMG) { + string partition = it.first; + + // Leader Partition already has its state machine + if (partition == state_machine.leader_partition) + continue; + + std::set* smgs = it.second; + + // Collect all the state machine graphs that smgs depened on + GatherPartitionStateMachines(state_machine, smgs); + + for (auto& it : *smgs) + AddPartitionStateMachine(state_machine, main_graphDef, it, partition); + } + + // Deallocate space + for (auto& it : 
state_machine.partitionsToSMG) + delete it.second; + state_machine.partitionsToSMG.clear(); + + for (auto& it: state_machine.state_machine_graphs) + DeleteStateMachineGraph(state_machine, it.first); + state_machine.state_machine_graphs.clear(); + } + } else + ready_nodes.push_back(out); + } + } + } + + // Deallocate space + for (auto& it : funcInfo.funcCalls) + delete it.second; + +/****************************************************************************/ + printf("\n\nSummarize Main Graph\n %s\n", SummarizeGraphDef(main_graphDef).c_str()); + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("Full_Partitioned"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = main_graphDef.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + main_graphDef.GetTypeName(), "' and size ", proto_size); + } + main_graphDef.SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); +/****************************************************************************/ + + return OkStatus(); +} + + + +/**************************************************************************************************/ + Status Partition(const PartitionOptions& opts, Graph* g, std::unordered_map* partitions) { // TODO(b/290689453) Refactor this into smaller functions @@ -977,14 +1631,40 @@ Status Partition(const PartitionOptions& opts, Graph* g, partitions->clear(); GraphInfo g_info; + std::unique_ptr new_g(new Graph(OpRegistry::Global())); if (!opts.control_flow_added) { // Add the "code" for distributed execution of control flow. Code is // added only for the frames that are placed on multiple devices. 
The // new graph is an equivalent transformation of the original graph and // has the property that it can be subsequently partitioned arbitrarily // (down to the level of individual device) for distributed execution. + + GraphDef main_graphDef; + g->ToGraphDef(&main_graphDef); + printf("\n\nSummarize Main Graph:\n %s\n\n", SummarizeGraphDef(main_graphDef).c_str()); + status = AddControlFlow(opts, g, &g_info); if (!status.ok()) return status; + + GraphDef gdef; + status = AddFunctionStateMachines(opts, g, gdef, &g_info); + if (status.ok()) { + // Convert GraphDef back to Graph so it can be partitioned + GraphConstructorOptions gopts; + gopts.allow_internal_ops = true; + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(gopts, gdef, new_g.get())); + g = new_g.get(); + + // The graph conversion sets the requested device names but not the assigned + // device names. However, since at this point the graph is placed TF expects + // an assigned device name for every node. Therefore we copy the requested + // device into the assigned device field. + for (Node* node : g->nodes()) { + node->set_assigned_device_name(node->requested_device()); + } + } else return status; + } // At this point, all the graph mutations have been done. Build memory @@ -1058,7 +1738,20 @@ Status Partition(const PartitionOptions& opts, Graph* g, int32_t num_input_edges = 0; for (const Edge* edge : dst->in_edges()) { if (edge->IsControlEdge()) { - if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { + if ((IsMerge(edge->src()) && IsControlLoop(edge->src())) || + (IsCallSuccessor(edge->src()) && (!IsConstant(edge->dst()) || + edge->dst()->in_edges().size() > 1))) { + // Note: not all control edges are control flow edges. + // There are also control edges added in + // FunctionTransformation for ensuring that Constants will execute in the + // correct 'frame'. + // We made sure in AddFunctionsStateMachines that: + // if a Constant in partition A, has such incoming edge from a CallSuccessor(..) 
+ // node, then this node will definitely belong in the same A partition, so we + // can safely add those edges in "inputs" as we do with common control edges. + // All the other edges whose src node is a CallSuccessor node are control flow edges. + + // This is one of the control edges added for control flow. There // can be multiple such edges as the dest node may have multiple // remote inputs. We keep track of the number of such edges. diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e0981fe90c8ae9..46b04a24f41461 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -150,6 +150,11 @@ bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } +bool IsCall(const NodeDef& node) { + const auto& op = node.op(); + return op == "Call" || op == "RefCall"; +} + bool IsConcat(const NodeDef& node) { return node.op() == "Concat" || node.op() == "ConcatV2"; } @@ -498,6 +503,11 @@ bool IsRetval(const NodeDef& node) { return node.op() == "_Retval" || node.op() == "_DeviceRetval"; } +bool IsReturn(const NodeDef& node) { + const auto& op = node.op(); + return op == "Return" || op == "RefReturn"; +} + bool IsReverse(const NodeDef& node) { return node.op() == "Reverse" || node.op() == "ReverseV2"; } @@ -777,7 +787,7 @@ bool ModifiesInputsInPlace(const NodeDef& node) { } bool ModifiesFrameInfo(const NodeDef& node) { - return IsEnter(node) || IsExit(node) || IsNextIteration(node); + return IsEnter(node) || IsExit(node) || IsNextIteration(node) || IsCall(node) || IsReturn(node); } #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \ diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index c233b6e9c6b61a..9546e948136177 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,6 +55,7 @@ bool IsCheckNumerics(const NodeDef& node); 
bool IsCollective(const NodeDef& node); bool IsComplex(const NodeDef& node); bool IsComplexAbs(const NodeDef& node); +bool IsCall(const NodeDef& node); bool IsConcat(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConj(const NodeDef& node); @@ -163,6 +164,7 @@ bool IsRsqrt(const NodeDef& node); bool IsRsqrtGrad(const NodeDef& node); bool IsSelect(const NodeDef& node); bool IsSeluGrad(const NodeDef& node); +bool IsReturn(const NodeDef& node); bool IsSend(const NodeDef& node); bool IsShape(const NodeDef& node); bool IsShapeN(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e967c46836756d..55ad6ade3326df 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -445,6 +445,47 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "function_transformation", + srcs = ["function_transformation.cc"], + hdrs = [ + "function_transformation.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + ":function_optimizer", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:functions", + ], +) + +tf_cc_test( + name = "function_transformation_test", + srcs = ["function_transformation_test.cc"], + shard_count = 5, + deps = [ + ":function_transformation", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + ], +) + cc_library( name = "model_pruner", 
srcs = ["model_pruner.cc"], @@ -642,6 +683,7 @@ cc_library( ":dependency_optimizer", ":function_optimizer", ":generic_layout_optimizer", + ":function_transformation", ":graph_optimizer", ":implementation_selector", ":loop_optimizer", diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc new file mode 100644 index 00000000000000..1cc92780cef45c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -0,0 +1,970 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include +#include +#include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { +namespace grappler { +namespace { + +static constexpr const char* const kCallOp = "Call"; +static constexpr const char* const kRetOp = "Return"; +static constexpr const char* const kIdentityOp = "Identity"; +static constexpr const char* const kIdentityNOp = "IdentityN"; +static constexpr const char* const kMergeOp = "Merge"; +static constexpr const char* const kGradientOp = + FunctionLibraryDefinition::kGradientOp; +static constexpr const char* const kFuncAttrName = + FunctionLibraryDefinition::kFuncAttr; +static constexpr const char* kNoInlineAttr = "_noinline"; + +bool AttrIsTrue(const FunctionDef& func, const string& attr) { + return func.attr().count(attr) != 0 && func.attr().at(attr).b(); +} + +bool MarkedNoInline(const FunctionDef& func) { + return AttrIsTrue(func, kNoInlineAttr); +} + +// There are two ways of calling a Tensorflow function: +// +// 1. 
Direct function call: node.op() is the name of the function. +// +// 2. Indirect function call: the function name is passed through a node +// attribute, and special Tensorflow kernels are responsible for calling the +// function through the FunctionLibraryRuntime. Example: PartitionedCallOp. + +// Check if func_node.op() matches the name in FunctionDef signature. +bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { + return func_node.op() == func.signature().name(); +} + +// Check if func_node has function attribute with a function name matching +// FunctionDef signature. +bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { + auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName); + return func_attr != nullptr && func_attr->has_func() && + func_attr->func().name() == func.signature().name(); +} + +// Copy input/output argument type to the type_list. Return error if argument +// type is not explicitly defined, and not specified in function attributes. +Status CopyArgType(const OpDef::ArgDef& arg, + const AttrSlice& func_attr, + DataType* type) { + if (arg.type() != DT_INVALID) { + *type = arg.type(); + } else { + const AttrValue* it = func_attr.Find(arg.type_attr()); + if (it == nullptr || it->type() == DT_INVALID) { + return errors::InvalidArgument( + "Invalid argument ", arg.name()); + } + *type = it->type(); + } + return OkStatus(); +} + +// Copy input/output argument type to the type_list. Return error if argument +// type is not explicitly defined, and not specified in function attributes. 
+Status CopyArgType(const OpDef::ArgDef& arg, + const AttrSlice& func_attr, + AttrValue::ListValue* type_list) { + if (arg.type() != DT_INVALID) { + type_list->add_type(arg.type()); + } else { + const AttrValue* it = func_attr.Find(arg.type_attr()); + if (it == nullptr || it->type() == DT_INVALID) { + return errors::InvalidArgument("Invalid argument ", arg.name()); + } + type_list->add_type(it->type()); + } + return OkStatus(); +} + + +AttrSlice FunctionInstantiationAttributes(const FunctionDef& func, + const NodeDef& func_node) { + if (IsDirectFunctionCall(func, func_node)) { + return AttrSlice(func_node); + + } else if (IsIndirectFunctionCall(func, func_node)) { + auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName); + return AttrSlice(&func_attr->func().attr()); + + } else { + LOG(WARNING) << "Can't resolve function instantiation attributes: " + << SummarizeNodeDef(func_node); + return AttrSlice(); + } +} + +struct FuncInfo { + DataTypeVector arg_types; + DataTypeVector ret_types; + std::vector args; + std::vector rets; +}; + +struct FuncGradInfo { + FuncInfo f; + FuncInfo g; +}; + +// same with commit a9a3b98 (possibly) +class FunctionInliningContext { + public: + explicit FunctionInliningContext(const GrapplerItem& item) + : function_library_(FunctionLibraryDefinition(OpRegistry::Global(), + item.graph.library())) { + InitializeInlinedFunctions(item); + InitializeFetchNodes(item); + } + + + const FunctionLibraryDefinition& FunctionLibrary() const { + return function_library_; + } + + Status AddFunctionDef(const FunctionDef& fdef) { + TF_RETURN_IF_ERROR(function_library_.AddFunctionDef(fdef)); + inlined_functions_[fdef.signature().name()] = function_library_.Find(fdef.signature().name()); + return OkStatus(); + } + + + bool HasInlinedFunctions() const { return !inlined_functions_.empty(); } + + bool IsInlinedFunction(const string& name) const { + return inlined_functions_.count(name) > 0; + } + + // Find inlining candidate by name. 
Return nullptr if not found. + const FunctionDef* FindInlinedFunction(const string& name) const { + return gtl::FindWithDefault(inlined_functions_, name, nullptr); + } + + bool IsFetchNode(const string& node_name) const { + return fetch_nodes_.find(node_name) != fetch_nodes_.end(); + } + + const FunctionDef* FindInlinedFunctionAndGradient(const string& name) const { + string grad_name = strings::StrCat(name, "Grad"); + return FindInlinedFunction(grad_name); + } + + private: + void InitializeInlinedFunctions(const GrapplerItem& item) { + for (const FunctionDef& func : item.graph.library().function()) { + + printf("Func name %s\n",func.signature().name().c_str()); + + bool marked_noinline = MarkedNoInline(func); + // Don't inline functions marked as noinline + if (marked_noinline) { + continue; + } + // Don't touch anything marked XLA to prevent XLA failures further down + // the road. + if (func.attr().count("_XlaCompile") > 0 && + func.attr().at("_XlaCompile").b()) { + continue; + } + // Can't create IdentityN nodes with no input or output: skip these + // functions for now. 
+ if (func.signature().input_arg_size() == 0 || + func.signature().output_arg_size() == 0) { + continue; + } + inlined_functions_[func.signature().name()] = &func; + } + } + + void InitializeFetchNodes(const GrapplerItem& item) { + for (const string& fetch : item.fetch) { + fetch_tensors_.insert(fetch); + fetch_nodes_.insert(NodeName(fetch)); + } + } + + FunctionLibraryDefinition function_library_; + std::unordered_map inlined_functions_; + gtl::FlatSet fetch_tensors_; // format: node_name:port + gtl::FlatSet fetch_nodes_; // format: node_name + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); +}; + +struct CallInfo { + int call_id; + string call_frame; + NodeDef* fcall = nullptr; + NodeDef* gcall = nullptr; + bool hasGradient() const { return (gcall != nullptr); } +}; + +struct TransformationResult { + int call_id; + string call_frame; + NodeDef* transformed_node; + std::vector call_nodes; + std::vector ret_nodes; +}; + +class CallRewriter { + + public: + explicit CallRewriter(const GrapplerItem& item_, GraphDef* graph_, FunctionInliningContext& ctx_) + : graph(graph_), ctx(ctx_), item(item_) { } + + ~CallRewriter() { + Flush(); + } + + Status CollectCalls(std::vector& calls); + + Status TransformCall(const CallInfo& call_info); + + // Inlines a function to item.graph and if already inlined provide func_info + Status FindCompatibleOrInlineFunction(const CallInfo& call, + GraphDef* optimized_graph, + FuncGradInfo& func_info); + + void Flush(); + + inline int GetCallId(const NodeDef& node) { int call_id = id; id++; return call_id; } + + private: + Status TransformNode(const CallInfo& info, + NodeDef* call, const FuncInfo& f, + std::vector& call_nodes, + std::vector& ret_nodes, + bool is_gradient_node); + + void ReplaceOutput(const string& old_output, const string& new_output) { + // maybe some more checks + output_map_[old_output] = new_output; + } + + void MarkCallTransformed(const CallInfo& call_info) { + CHECK_NOTNULL(call_info.fcall); + 
MarkNodeDelete(call_info.fcall); + + if (call_info.gcall != nullptr) { + MarkNodeDelete(call_info.gcall); + } + } + + void MarkTransformed(TransformationResult& result) { + NodeDef* n = result.transformed_node; + CHECK_NOTNULL(n); + transformed_calls_[result.transformed_node->name()] = result; + n->clear_input(); + n->set_op("NoOp"); + n->set_name(AddPrefixToNodeName(n->name(), "$MarkToDelete$")); + nodes_to_delete.insert(n->name()); + } + + void MarkNodeDelete(NodeDef* n) { + n->clear_input(); + n->set_op("NoOp"); + n->set_name(AddPrefixToNodeName(n->name(), "$MarkToDelete$")); + nodes_to_delete.insert(n->name()); + } + + GraphDef* graph; + FunctionInliningContext& ctx; + const GrapplerItem& item; + std::unordered_map transformed_functions_; + std::unordered_map output_map_; + std::unordered_map transformed_calls_; + std::set nodes_to_delete; + int id = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(CallRewriter); +}; + +Status AddCallOp(const CallInfo& call_info, + const DataType& type, + const string& input, + const string& prefix, + int arg_id, NodeDef* call, bool is_gradient_call = false) { + string call_name = strings::StrCat("Call", "_", arg_id); + call->set_op(kCallOp); + call->set_name(AddPrefixToNodeName(call_name, prefix)); + //call->set_device(node.device()); + call->add_input(input); + + auto& attr = *call->mutable_attr(); + attr["T"].set_type(type); + attr["frame_name"].set_s(call_info.call_frame); + attr["call_id"].set_i(call_info.call_id); + attr["arg_id"].set_i(arg_id); + attr["is_constant"].set_b(false); + attr["is_gradient"].set_b(is_gradient_call); + + return OkStatus(); +} + +Status AddRetOp(const CallInfo& call_info, + const DataType& type, + const string& input, + const string& prefix, + int arg_id, NodeDef* ret, bool is_gradient_return = false) { + string ret_name = strings::StrCat("Ret", "_", arg_id); + ret->set_op(kRetOp); + ret->set_name(AddPrefixToNodeName(ret_name, prefix)); + ret->add_input(input); + + auto& attr = *ret->mutable_attr(); + 
attr["T"].set_type(type); + attr["frame_name"].set_s(call_info.call_frame); + attr["call_id"].set_i(call_info.call_id); + attr["arg_id"].set_i(arg_id); + attr["is_gradient"].set_b(is_gradient_return); + + return OkStatus(); +} + +Status ConnectInput(NodeDef* from, NodeDef* to) { + int to_input = to->input_size(); + if (to_input == 1) { + // it is Identity and we convert it to Merge. + CHECK(IsIdentity(*to)); + to->set_op(kMergeOp); + } + to->add_input(from->name()); + if (to->input_size() > 1) { + (*to->mutable_attr())["N"].set_i(to->input_size()); + } + return OkStatus(); +} + +Status InlineFunction(const FunctionDef& func_def, + const AttrSlice& func_instantiation_attr, + const FunctionInliningContext& ctx, + const string& device, + GraphDef* graph, FuncGradInfo& func_info) { + GrapplerFunctionItem item; + const int graph_version = graph->versions().producer(); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func_def, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); + + string prefix = func_def.signature().name(); + int arg_size = func_def.signature().input_arg_size(); + // create an inverse map of arg to provide name -> argument number + std::unordered_map input_nodes; + for (int i = 0; i < arg_size; ++i) { + const OpDef::ArgDef& input_arg = func_def.signature().input_arg(i); + input_nodes[input_arg.name()] = i; + } + func_info.f.args.resize(arg_size); + func_info.f.arg_types.resize(arg_size); + for (int i = 0; i < arg_size; ++i) { + const OpDef::ArgDef& input_arg = func_def.signature().input_arg(i); + NodeDef* merge = graph->add_node(); + merge->set_name(AddPrefixToNodeName(strings::StrCat("Input", "_", i), prefix)); + merge->set_op(kIdentityOp); + merge->set_device(device); + + DataType type; + TF_RETURN_IF_ERROR(CopyArgType(input_arg, func_instantiation_attr, &type)); + auto& attr = *merge->mutable_attr(); + attr["T"].set_type(type); + + func_info.f.args[i] = merge; + func_info.f.arg_types[i] = type; + } + + // prefix each node in 
function graph and place it to the global graph. + // the inputs of each node need to be renamed as well to reflect the change. + for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) { + const string& curr_name = func_body_node.name(); + // If the func body node is func's input argument + auto input_it = input_nodes.find(curr_name); + + if (input_it != input_nodes.end()) { + CHECK_EQ(0, func_body_node.input_size()); + // If the func body node is func's input argument + // Turn input placeholders into identity nodes + func_body_node.set_op(kIdentityOp); + // Connect merge with input arg + int idx = input_nodes[curr_name]; + func_body_node.add_input(func_info.f.args[idx]->name()); + } else { + // Else if not an input_arg_node + // Update the input names if any. + for (string& input : *func_body_node.mutable_input()) { + input = AddPrefixToNodeName(input, prefix); + } + // If this is a return node, change the op to kIdentityOp + if(IsRetval(func_body_node)){ + func_body_node.set_op(kIdentityOp); + } + + // If the node has no input, hook it up to the Merge nodes to ensure + // it runs in the same frame as the other nodes of the function body. 
+ if (func_body_node.input_size() == 0) { + for (auto& func_input_node : func_info.f.args) { + *func_body_node.add_input() = AsControlDependency(func_input_node->name()); + } + } + } + + // Add the node name as a prefix to avoid collisions after inlining + func_body_node.set_name(AddPrefixToNodeName(curr_name, prefix)); + + // Make sure the node is placed + if (func_body_node.device().empty()) + func_body_node.set_device(device); + + // Move the node to the main graph + graph->add_node()->Swap(&func_body_node); + } + + func_info.f.rets.clear(); + func_info.f.rets.resize(item.fetch.size()); + func_info.f.ret_types.resize(item.fetch.size()); + + std::vector fetch = item.fetch; + for (unsigned int i = 0; i < fetch.size(); i++) { + const OutputArgInstantiation& output_arg = item.output(i); + func_info.f.rets[i] = AddPrefixToNodeName(output_arg.node_name, prefix); + func_info.f.ret_types[i] = output_arg.data_type; + } + + return OkStatus(); +} + +Status InlineFunctionAndGradient(const FunctionDef& fdef, + const AttrSlice& func_instantiation_attr, + FunctionInliningContext& ctx, + const string& device, + GraphDef* graph, + FuncGradInfo& func_info) { + // Get func_def's gradient graph + + const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); + if (fgdef == nullptr) { + return errors::InvalidArgument( + "Invalid argument, gradient of function ", fdef.signature().name(), "can not be found", + "or not marked to be inlined"); + } + + + + + GrapplerFunctionItem item; + const int graph_version = graph->versions().producer(); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(*fgdef, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); + + string prefix = fdef.signature().name(); + size_t farg_size = fdef.signature().input_arg_size(); + size_t fret_size = fdef.signature().output_arg_size(); + size_t garg_size = fgdef->signature().input_arg_size();// - farg_size; + size_t gret_size = fgdef->signature().output_arg_size();// - 
fret_size; + + CHECK_EQ(farg_size, gret_size - fret_size); + CHECK_EQ(garg_size, fret_size + farg_size); + + func_info.f.arg_types.resize(farg_size); + func_info.g.arg_types.resize(garg_size); + func_info.g.ret_types.resize(gret_size); + for (int i = 0; i < farg_size; i++) { + const OpDef::ArgDef& input_arg = fdef.signature().input_arg(i); + func_info.f.arg_types[i] = input_arg.type(); + func_info.g.arg_types[i] = input_arg.type(); + } + + func_info.f.ret_types.resize(fret_size); + for (int i = 0; i < gret_size; i++) { + // const OutputArgInstantiation& output_arg = item.output(i); + if(i < fret_size){ + func_info.f.ret_types[i] = item.output(i).data_type; + func_info.g.arg_types[farg_size + i] = item.output(i).data_type; + } + func_info.g.ret_types[i] = item.output(i).data_type; + } + + // create an inverse map of arg to provide name -> argument number + std::unordered_map input_map; + std::vector input_names; + input_names.resize(farg_size); + for (int i = 0; i < garg_size; ++i) { + input_map[item.input(i).node_name] = i; + if (i < farg_size) { + input_names[i] = item.input(i).node_name; + } + } + func_info.f.args.resize(farg_size); + func_info.f.rets.resize(fret_size); + func_info.g.args.resize(garg_size); + func_info.g.rets.resize(gret_size); + + // prefix each node in function graph and place it to the global graph. + // the inputs of each node need to be renamed as well to reflect the change. + for (NodeDef& n : *item.mutable_function_body().mutable_node()) { + // If the func body node is func's input argument + auto input_it = input_map.find(n.name()); + bool is_input = input_it != input_map.end(); + + if (is_input) { + CHECK_EQ(0, n.input_size()); + n.set_op(kIdentityOp); + } + + // Add the node name as a prefix to avoid collisions after inlining + n.set_name(AddPrefixToNodeName(n.name(), prefix)); + // Update the input names if any. 
+ for (string& input : *n.mutable_input()) { + input = AddPrefixToNodeName(input, prefix); + } + + // Make sure the node is placed + if (n.device().empty()) + n.set_device(device); + + // if (n.op() == kGradientOp) { + // auto& attr = *n.mutable_attr(); + // std::string& name = *attr.at("f").mutable_func()->mutable_name(); + // name = AddPrefixToNodeName(name, prefix); + // } + if(IsRetval(n)){ + n.set_op(kIdentityOp); + } + + // If the node has no input, hook it up to the Merge nodes to ensure + // it runs in the same frame as the other nodes of the function body. + if (!is_input && n.input_size() == 0) { + // CHECK: constants from both in function and gradient are connected + // with the inputs of the function only. + for (const string& arg : input_names) { + *n.add_input() = AsControlDependency(AddPrefixToNodeName(arg, prefix)); + } + } + + // Move the node to the main graph + NodeDef* nn = graph->add_node(); + nn->Swap(&n); + + if (is_input) { + int i = input_it->second; + if (i < farg_size) { + func_info.f.args[i] = nn; + func_info.g.args[i] = func_info.f.args[i]; + } else { + func_info.g.args[i] = nn; + } + } + } + + CHECK_EQ(gret_size, item.fetch.size()); + + for (unsigned int i = 0; i < gret_size; i++) { + string output_port = AddPrefixToNodeName(item.output(i).node_name, prefix); + if (i < fret_size) { + func_info.f.rets[i] = output_port; + } + func_info.g.rets[i] = output_port; + } + + return OkStatus(); +} + +Status CallRewriter::CollectCalls(std::vector& calls) { + + std::unordered_map call_map; + std::vector gradients; + + // identify and collect calls in the graph + for (NodeDef& node : *graph->mutable_node()) { + if (node.op() == kGradientOp) { + gradients.push_back(&node); + } else { + const FunctionDef* func_def = ctx.FindInlinedFunction(node.op()); + if (func_def != nullptr) { + CallInfo& call = call_map[node.op()]; + call.call_id = GetCallId(node); + call.call_frame = node.op(); + call.fcall = &node; + } + } + } + for (NodeDef* gcall : 
gradients) { + if (gcall->attr().count("f") > 0) { + printf("Debug string: %s \n\n", gcall->attr().at("f").DebugString().c_str()); + const string& n = gcall->attr().at("f").func().name(); + + auto fcall_it = call_map.find(n); + if (fcall_it == call_map.end()) { + // return errors::InvalidArgument("Cannot find forward node for gradient ", + // gcall->name()); + continue; + } + CallInfo& call = fcall_it->second; + call.gcall = gcall; + } + } + + for (const auto& it : call_map) { + calls.push_back(it.second); + } + return OkStatus(); +} + +Status CallRewriter::TransformNode(const CallInfo& info, + NodeDef* call, + const FuncInfo& f, + std::vector& call_nodes, + std::vector& ret_nodes, bool is_gradient_node = false) { + CHECK_EQ(call->input_size(), f.args.size()); + + unsigned int next_return_node = is_gradient_node ? ret_nodes.size() : 0; + + call_nodes.resize(f.args.size()); + for (unsigned int i = 0; i < f.args.size(); i++) { + /* check if call node is already in place, if so, validate and skip */ + if (call_nodes[i] != nullptr) { + // TODO: validate call_id + // TODO: validate input + //CHECK_EQ(call_nodes[i]->input(0), call->input(i)); + } else { + call_nodes[i] = graph->add_node(); + TF_CHECK_OK(AddCallOp(info, + f.arg_types[i], + call->input(i), + call->name(), + i, + call_nodes[i], + is_gradient_node)); + + call_nodes[i]->set_device(call->device()); + + // connect the input of the inlined function to feed from call. 
+ TF_RETURN_IF_ERROR(ConnectInput(call_nodes[i], f.args[i])); + } + } + + // check for control edges in call + gtl::FlatSet control_inputs; + for (const string& input : call->input()) { + if (IsControlInput(input)) { + control_inputs.insert(NodeName(input)); + } + } + + for (NodeDef* call_node : call_nodes) { + for (const string& control_input : control_inputs) + *(call_node->add_input()) = AsControlDependency(control_input); + } + + ret_nodes.resize(f.rets.size()); + for (unsigned int i = 0; i < f.rets.size(); i++) { + if (ret_nodes[i] != nullptr) { + // TODO: validate call_id + // CHECK_EQ(ret_nodes[i]->input(0), f.rets[i]); + } else { + ret_nodes[i] = graph->add_node(); + TF_CHECK_OK(AddRetOp(info, + f.ret_types[i], + f.rets[i], + call->name(), + i, + ret_nodes[i], + is_gradient_node)); + ret_nodes[i]->set_device(call->device()); + } + } + + if (ctx.IsFetchNode(call->name())) { + // create an IdentityN with the same name as the initial function call + // so as to preserve the naming of the outputs. + // we add a new node and we set (a) its op to IdentityN and + // (b) its inputs to point to the outputs of the ret_nodes + // The other information such as types, device placement etc remain the same. + // The IdentityN node will sync the outputs and therefore may result in performance degradation. 
+ NodeDef* out = graph->add_node(); + out->set_op(kIdentityNOp); + out->set_name(call->name()); + out->set_device(call->device()); + AttrValue::ListValue* type_list = (*out->mutable_attr())["T"].mutable_list(); + for (const DataType& type : f.ret_types) { + type_list->add_type(type); + } + for (unsigned int i = 0; i < f.rets.size(); i++) { + *out->add_input() = ret_nodes[i]->name(); + } + } else { + for (unsigned int i = next_return_node; i < f.rets.size(); i++) { + ReplaceOutput(strings::StrCat(call->name(), ":", i - next_return_node), ret_nodes[i]->name()); + if(i == next_return_node)ReplaceOutput(call->name(), ret_nodes[i]->name()); + } + } + + // for each call create a control dependency to each return + // to facilitate dead propagation semantics + for (NodeDef* ret : ret_nodes) { + for (NodeDef* call : call_nodes){ + if(ret->attr().at("is_gradient").b() != call->attr().at("is_gradient").b()) continue; + printf("Adding control edge from %s to %s\n",call->name().c_str(),ret->name().c_str()); + // TODO: Check if there is already a control dependency. 
+ *(ret->add_input()) = AsControlDependency(call->name()); + } + } + + return OkStatus(); +} + +Status CallRewriter::TransformCall(const CallInfo& call_info) { + FuncGradInfo func_info; + TransformationResult result; + + // inlines the body of a function and provides a struct with func_info + TF_RETURN_IF_ERROR(FindCompatibleOrInlineFunction(call_info, graph, func_info)); + + result.call_id = call_info.call_id; + result.call_frame = call_info.call_frame; + result.transformed_node = call_info.fcall; + + TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.fcall, func_info.f, result.call_nodes, result.ret_nodes,false)); + MarkTransformed(result); + + if (call_info.hasGradient()) { + TransformationResult grad_result; + grad_result.call_id = call_info.call_id; + grad_result.call_frame = call_info.call_frame; + grad_result.transformed_node = call_info.gcall; + grad_result.call_nodes = result.call_nodes; + grad_result.ret_nodes = result.ret_nodes; + // keep all the inputs of the function + TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.gcall, func_info.g, grad_result.call_nodes, grad_result.ret_nodes,true)); + MarkTransformed(grad_result); + } + MarkCallTransformed(call_info); + return OkStatus(); +} + +Status CallRewriter::FindCompatibleOrInlineFunction( + const CallInfo& call, + GraphDef* graph, + FuncGradInfo& func_info) { + CHECK_NOTNULL(call.fcall); + const string& func_name = call.fcall->op(); + string device = call.fcall->device(); + const auto& it = transformed_functions_.find(func_name); + // maybe it is not wise to discard call attributes + // possible type specialization? 
+ if (it != transformed_functions_.end()) { + func_info = it->second; + return OkStatus(); + } + const FunctionDef* func_def = ctx.FindInlinedFunction(func_name); + if (func_def == nullptr) { + return errors::InvalidArgument( + "Invalid argument, function ", func_name, "can not be found", + "or not marked to be inlined"); + } + + const AttrSlice func_instantiation_attr = + FunctionInstantiationAttributes(*func_def, *call.fcall); + + if (call.hasGradient()) { + TF_RETURN_IF_ERROR( + InlineFunctionAndGradient(*func_def, func_instantiation_attr, ctx, device, graph, func_info)); + } else { + TF_RETURN_IF_ERROR( + InlineFunction(*func_def, func_instantiation_attr, ctx, device, graph, func_info)); + } + transformed_functions_[func_name] = func_info; + printf("Store inlined function %s\n", func_name.c_str()); + return OkStatus(); +} + +void CallRewriter::Flush() { + + if (!transformed_calls_.empty()) { + // garbage collect the transformed call nodes + int last = graph->node_size() - 1; + for (int i = graph->node_size() - 1; i >= 0; --i) { + const NodeDef& node = graph->node(i); + if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) { + graph->mutable_node()->SwapElements(i,last); + last--; + } + } + graph->mutable_node()->DeleteSubrange(last + 1, + graph->node_size() - last - 1); + } + + + // for(auto& p : output_map_){ + // printf("%s -> %s\n",p.first.c_str(),p.second.c_str()); + + // } + + if (!output_map_.empty()) { + for (NodeDef& node : *graph->mutable_node()) { + std::vector control_nodes; + int last = node.input_size() - 1; + + for (int i = node.input_size() - 1; i >= 0; --i) { + string& in = *node.mutable_input(i); + auto it = output_map_.find(in); + if (it != output_map_.end()) { + in = it->second; + } + if (IsControlInput(in)) { + auto it = transformed_calls_.find(NodeName(in)); + if (it != transformed_calls_.end()) { + node.mutable_input()->SwapElements(i, last); + control_nodes.push_back(it->second); + last--; + } + } + 
node.mutable_input()->DeleteSubrange(last + 1, + node.input_size() - last - 1); + for (TransformationResult& result : control_nodes) { + for (NodeDef* ret_node : result.ret_nodes) { + *node.add_input() = AsControlDependency(ret_node->name()); + } + } + } + } + } + transformed_calls_.clear(); + nodes_to_delete.clear(); + output_map_.clear(); +} + +} // namespace + +Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + FunctionInliningContext ctx(item); + CallRewriter call_rewriter(item, output, ctx); + + *output = item.graph; + + printf("Before optimizer: %s\n\n",SummarizeGraphDef(*output).c_str()); + if (!ctx.HasInlinedFunctions()) { + return OkStatus(); + } + + std::vector calls; + while (1) { + TF_RETURN_IF_ERROR(call_rewriter.CollectCalls(calls)); + if (calls.empty()) { + break; + } + for (const CallInfo& call : calls) { + Status s = call_rewriter.TransformCall(call); + if (!s.ok()) { + printf("Error: %s\n", tsl::NullTerminatedMessage(s)); + return s; + } + printf("After transforming call %s:\n %s\n", call.fcall->name().c_str(), SummarizeGraphDef(*output).c_str()); + } + calls.clear(); + call_rewriter.Flush(); + } + call_rewriter.Flush(); + + + printf("After finalizing:\n %s\n", SummarizeGraphDef(*output).c_str()); + + // for (NodeDef& node : *output->mutable_node()){ + + // if(node.op() != kGradientOp)continue; + // NameAttrList func; + // TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node), kFuncAttrName, &func)); + // gradient::Creator creator; + // TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator)); + // if (creator == nullptr) { + // return absl::InvalidArgumentError( + // absl::StrCat("No gradient is defined for ", func.name())); + // } + // FunctionDef grad_fdef; + + // std::unique_ptr* fbody; + // TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef)); + // TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + // grad_fdef, AttrSlice(&func.attr()), &ctx.FunctionLibrary(), fbody)); 
+ + // printf("Gradient of of %s:\n%s\n\n",func.name().c_str(),SummarizeGraphDef((*fbody)->graph->ToGraphDefDebug()).c_str()); + + + // } + + + + + + + *output->mutable_versions() = item.graph.versions(); + + // Function Library should be pruned of unreachable function definitions + // cf. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/function_optimizer.cc#L428 + // however in this version there is a check in meta_optimizer that guarantees + // that function library remains of the same length + // cf. https://github.com/acharal/tensorflow/blob/r1.4_recursion/tensorflow/core/grappler/optimizers/meta_optimizer.cc#L132 + *output->mutable_library() = item.graph.library(); + + + + /******************************************************************************************************/ + // Dumps optimized graph in a not so readable form + // const GraphDef* tmp = optimized_graph; + // printf("Summarize Optimized Graph\n %s\n", SummarizeGraphDef(*tmp).c_str()); + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("TRANSFORMATION"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + const size_t proto_size = output->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + output->GetTypeName(), "' and size ", proto_size); + } + output->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + /******************************************************************************************************/ + + return OkStatus(); +} + +} // end namespace grappler +} // end namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/grappler/optimizers/function_transformation.h b/tensorflow/core/grappler/optimizers/function_transformation.h new file mode 100644 index 
00000000000000..2136ac5ccf7236 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation.h @@ -0,0 +1,41 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + + +// Replace function calling nodes with pairs of new 'Call/Return' operators +class FunctionTransformation : public GraphOptimizer { + public: + explicit FunctionTransformation() {} + ~FunctionTransformation() override = default; + + string name() const override { return "function_transformation"; }; + + bool UsesFunctionLibrary() const override { return true; } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ diff --git a/tensorflow/core/grappler/optimizers/function_transformation_test.cc b/tensorflow/core/grappler/optimizers/function_transformation_test.cc new file mode 100644 index 00000000000000..751278545984e2 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation_test.cc @@ -0,0 +1,57 @@ +/* + +Licensed under 
the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { +namespace { + + +class FunctionTransformationTest : public ::testing::Test { + +}; + +TEST_F(FunctionTransformationTest, NoTrans) { + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + Output a = ops::Const(s.WithOpName("a"), 1.0f, {1}); + Output b = ops::Const(s.WithOpName("b"), 2.0f, {1}); + Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b}); + Output d = ops::AddN(s.WithOpName("d"), {b, c}); + + GrapplerItem item; + item.fetch.push_back("d"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + FunctionTransformation func_trans; + GraphDef output; + Status status = func_trans.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc 
b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 84afab6e12badf..f2a1e80e4f09e5 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" +#include "tensorflow/core/grappler/optimizers/function_transformation.h" #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/implementation_selector.h" #include "tensorflow/core/grappler/optimizers/loop_optimizer.h" @@ -107,7 +108,7 @@ int NumIterations(const RewriterConfig& cfg) { bool IsRunOnceOptimizer(const string& name) { return name == "layout" || name == "memory_optimizer" || name == "loop_optimizer" || - absl::StartsWith(name, "auto_mixed_precision"); + absl::StartsWith(name, "auto_mixed_precision") || name == "function_transformation"; } // Creates a function library stub from a real function library: copy only @@ -211,6 +212,12 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types); if (optimizer == "pruning" && !plugin_configs.disable_model_pruning) return std::unique_ptr(new ModelPruner()); + + + if (LowerControlFlow()) { + MK_OPT("function_transformation", "function_transformation", new FunctionTransformation()); + } + MK_OPT("function", "function_optimization", new FunctionOptimizer(cfg_.function_optimization(), /*lower_control_flow=*/LowerControlFlow())); @@ -330,6 +337,14 @@ Status MetaOptimizer::InitializeOptimizers( else optimizers->push_back(std::make_unique()); } + if (BOTH_NOT_OFF(function_transformation)) { + if (USER_IS_EXPERIMENTAL_MLIR(function_transformation) || + USER_IS_EXPERIMENTAL_BOTH(function_transformation)) { + VLOG(2) << 
"function_transformation is not implemented in TFG yet"; + } else { + optimizers->push_back(std::make_unique()); + } + } if (BOTH_NOT_OFF(function_optimization)) { if (USER_IS_EXPERIMENTAL_MLIR(function_optimization) || USER_IS_EXPERIMENTAL_BOTH(function_optimization)) { @@ -641,6 +656,7 @@ void MetaOptimizer::PrintUserAndPluginConfigs( PRINT_CFG(loop_optimization) PRINT_CFG(dependency_optimization) PRINT_CFG(scoped_allocator_optimization) + PRINT_CFG(function_transformation) #undef PRINT_CFG user_cfg.toggle_config["auto_mixed_precision"] = AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) @@ -696,6 +712,7 @@ void MetaOptimizer::PrintUserAndPluginConfigs( PRINT_CFG("memory", "memory_optimization") PRINT_CFG("autoparallel", "auto_parallel") PRINT_CFG("scoped_allocator", "scoped_allocator_optimization") + PRINT_CFG("function_transformation", "function_transformation") #undef PRINT_CFG } } @@ -759,12 +776,12 @@ Status MetaOptimizer::OptimizeGraph( Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph) { int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes : cfg_.min_graph_nodes(); - if (item.graph.node_size() < min_graph_nodes) { - VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes - << " nodes."; - *optimized_graph = item.graph; - return absl::OkStatus(); - } + // if (item.graph.node_size() < min_graph_nodes) { + // VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes + // << " nodes."; + // *optimized_graph = item.graph; + // return OkStatus(); + // } tensorflow::metrics::ScopedCounter<2> timings( tensorflow::metrics::GetGraphOptimizationCounter(), @@ -815,12 +832,12 @@ Status MetaOptimizer::OptimizeGraph( for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) { // Don't bother optimizing further if the graph is already tiny. 
- if (optimized_graph->node_size() < min_graph_nodes) { - VLOG(3) << "Stopping after iteration " << iteration - << ", graph is tiny (#nodes = " << optimized_graph->node_size() - << " < " << min_graph_nodes << ")"; - break; - } + // if (optimized_graph->node_size() < min_graph_nodes) { + // VLOG(3) << "Stopping after iteration " << iteration + // << ", graph is tiny (#nodes = " << optimized_graph->node_size() + // << " < " << min_graph_nodes << ")"; + // break; + // } VLOG(4) << "Starting optimization iteration " << iteration; if (VLOG_IS_ON(4)) { @@ -1353,6 +1370,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) { rewrite_cfg.auto_parallel().enable() || rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || rewrite_cfg.debug_stripper() == RewriterConfig::ON || + rewrite_cfg.function_transformation() != RewriterConfig::OFF || #ifndef ENABLE_MKL rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON || #endif diff --git a/tensorflow/core/grappler/optimizers/my_bad_transformation.cc b/tensorflow/core/grappler/optimizers/my_bad_transformation.cc new file mode 100644 index 00000000000000..e315e3b12b08ad --- /dev/null +++ b/tensorflow/core/grappler/optimizers/my_bad_transformation.cc @@ -0,0 +1,113 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include +#include +#include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" + +namespace tensorflow { +namespace grappler { +namespace { + + +class CallRewriter { + + +}; + + + +Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + FunctionInliningContext ctx(item); + CallRewriter call_rewriter(item, output, ctx); + + *output = item.graph; + if (!ctx.HasInlinedFunctions()) { + return OkStatus(); + } + + std::vector calls; + while (1) { + TF_RETURN_IF_ERROR(call_rewriter.CollectCalls(calls)); + if (calls.empty()) { + break; + } + for (CallInfo& call : calls) { + Status s = call_rewriter.TransformCall(call); + if (!s.ok()) { + printf("Error: %s\n", tsl::NullTerminatedMessage(s)); + return s; + } + printf("After transforming call %s:\n %s\n", call.function_name.c_str(), SummarizeGraphDef(*output).c_str()); + } + calls.clear(); + call_rewriter.Flush(); + } + call_rewriter.Flush(); + printf("After finalizing:\n %s\n", SummarizeGraphDef(*output).c_str()); + *output->mutable_versions() = item.graph.versions(); + + // Function Library should be pruned of 
unreachable function definitions + // cf. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/function_optimizer.cc#L428 + // however in this version there is a check in meta_optimizer that guarantees + // that function library remains of the same length + // cf. https://github.com/acharal/tensorflow/blob/r1.4_recursion/tensorflow/core/grappler/optimizers/meta_optimizer.cc#L132 + *output->mutable_library() = item.graph.library(); + + + + /******************************************************************************************************/ + // Dumps optimized graph in a not so readable form + // const GraphDef* tmp = optimized_graph; + // printf("Summarize Optimized Graph\n %s\n", SummarizeGraphDef(*tmp).c_str()); + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("TRANSFORMATION"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + const size_t proto_size = output->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + output->GetTypeName(), "' and size ", proto_size); + } + output->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + /******************************************************************************************************/ + + return OkStatus(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index c2d848aaa67ae6..94b7e3cadc39f0 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -608,7 +608,7 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, // Skip original `_Arg` and `_Retval` nodes. If node was converted to some // other type (e.g. 
inputs converted to placeholders), we need to check that // it's not registered as function input or output node. - if (IsArg(func_node) || IsRetval(func_node) || + if (IsArg(func_node) || IsRetval(func_node) || IsReturn(func_node) || helper.IsInputNode(func_node) || helper.IsOutputNode(func_node)) continue; diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 29e00240028715..1d2d7b7bc681ff 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -56,6 +56,7 @@ Status ComputeTopologicalOrder( // Keep track of how many inputs are ready for the given node. std::vector num_ready_inputs(graph.node_size(), 0); + std::unordered_map> returning_nodes; // We'll push index of ready nodes to this output vector. ready_nodes->reserve(graph.node_size()); @@ -68,12 +69,50 @@ Status ComputeTopologicalOrder( ready_nodes->push_back(i); back++; } + bool recursion_merge = false; if (IsMerge(graph.node(i))) { - for (int input : graph_view.GetFanin(i)) { + for (int input : graph_view.GetFanin(i)) { if (IsNextIteration(graph.node(input))) { num_ready_inputs[i]++; } + else if (IsCall(graph.node(input))) { + num_ready_inputs[i]++; + recursion_merge = true; + } } + if (recursion_merge) { + num_ready_inputs[i]--; + recursion_merge = false; + } + } else if (IsReturn(graph.node(i))) { + // Nodes that send their output to "Return" nodes are + // function Returning Nodes and in case of recursive functions + // those nodes are part of graph cycles. 
+ int id = 0; + num_ready_inputs[i] = 0; + for (int input : graph_view.GetFanin(i)) { + // In order to detect the recursion cycles we depend on + // the fact that a recursive function's returning node, + // will be sending outputs to at least 2 "Return" nodes + // with different "call_id" attributes (same "call_id" + // attrs would mean that they belong in the same function call + // but they correspond to different function outputs) + if (!absl::StartsWith(graph.node(i).input(id), "^")) { + // if (true) { + int call_id; + TF_CHECK_OK(GetNodeAttr(graph.node(i), "call_id", &call_id)); + returning_nodes[input].emplace(call_id); + num_ready_inputs[i]++; + } + id++; + } + } + } + + for (const auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + // num_ready_inputs[retnode.first]++; } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c541597c8d9d80..ffd85a9c0766f8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2528,6 +2528,15 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "function_control_ops", + prefix = "function_control_ops", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + cc_library( name = "data_flow", deps = [ diff --git a/tensorflow/core/kernels/function_control_ops.cc b/tensorflow/core/kernels/function_control_ops.cc new file mode 100644 index 00000000000000..89d5a356427031 --- /dev/null +++ b/tensorflow/core/kernels/function_control_ops.cc @@ -0,0 +1,191 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/function_control_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +void CallOpe::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } +} + +REGISTER_KERNEL_BUILDER(Name("Call").Device(DEVICE_CPU), CallOpe); +REGISTER_KERNEL_BUILDER(Name("RefCall").Device(DEVICE_CPU), CallOpe); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Call").Device(DEVICE_GPU).TypeConstraint("T"), CallOpe) +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefCall").Device(DEVICE_GPU).TypeConstraint("T"), CallOpe) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Call").Device(DEVICE_SYCL).TypeConstraint("T"), CallOpe) +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefCall").Device(DEVICE_SYCL).TypeConstraint("T"), CallOpe) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + +#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Call") \ + 
.Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefCall") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_REF_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +REGISTER_SYCL_HOST_REF_KERNEL(string); +REGISTER_SYCL_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_SYCL_HOST_KERNEL +#undef REGISTER_SYCL_HOST_REF_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Call") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +#define REGISTER_GPU_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefCall") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_REF_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(tstring); +REGISTER_GPU_HOST_REF_KERNEL(tstring); +REGISTER_GPU_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_GPU_HOST_KERNEL +#undef REGISTER_GPU_HOST_REF_KERNEL + +void ReturnOp::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } +} + +REGISTER_KERNEL_BUILDER(Name("Return").Device(DEVICE_CPU), ReturnOp); +REGISTER_KERNEL_BUILDER(Name("RefReturn").Device(DEVICE_CPU), ReturnOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Return").Device(DEVICE_GPU).TypeConstraint("T"), ReturnOp); +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefReturn").Device(DEVICE_GPU).TypeConstraint("T"), ReturnOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); 
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +#ifdef TENSORFLOW_USE_SYCL + #define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Return").Device(DEVICE_SYCL).TypeConstraint("T"), ReturnOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefReturn").Device(DEVICE_SYCL).TypeConstraint("T"), ReturnOp); +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL + +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Return") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp); \ + REGISTER_KERNEL_BUILDER(Name("RefReturn") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Return") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp); \ + REGISTER_KERNEL_BUILDER(Name("RefReturn") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(tstring); + +#undef REGISTER_GPU_HOST_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/function_control_ops.h b/tensorflow/core/kernels/function_control_ops.h new file mode 100644 index 00000000000000..b03d3eae9a39c5 --- /dev/null +++ b/tensorflow/core/kernels/function_control_ops.h @@ -0,0 +1,47 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_KERNELS_FUNCTION_CONTROL_OPS_H_ +#define TENSORFLOW_KERNELS_FUNCTION_CONTROL_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// A call op has one input and one output. It creates or finds +// the child frame that is uniquely identified by the frame_name, +// and makes its input available to the child frame. +class CallOpe : public OpKernel { +public: + explicit CallOpe(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~CallOpe() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(CallOpe); +}; + +// A Return op has one input and one output. It exits the current +// frame to its parent frame, and makes its input available to the +// parent frame only if it receives a tensor with a specific tag. 
+class ReturnOp : public OpKernel { +public: + explicit ReturnOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~ReturnOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(ReturnOp); +}; +} // namespace tensorflow + +#endif diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD index c1e0497969e7fc..44fbb61eea480a 100644 --- a/tensorflow/core/ops/BUILD +++ b/tensorflow/core/ops/BUILD @@ -58,6 +58,7 @@ tf_gen_op_libs( "filesystem_ops", "function_ops", "functional_ops", + "function_control_ops", "image_ops", "io_ops", "linalg_ops", @@ -293,6 +294,7 @@ cc_library( ":experimental_dataset_ops_op_lib", ":filesystem_ops_op_lib", ":function_ops_op_lib", + ":function_control_ops_op_lib", ":functional_ops_op_lib", ":image_ops_op_lib", ":io_ops_op_lib", diff --git a/tensorflow/core/ops/function_control_ops.cc b/tensorflow/core/ops/function_control_ops.cc new file mode 100644 index 00000000000000..c337160b0fa5da --- /dev/null +++ b/tensorflow/core/ops/function_control_ops.cc @@ -0,0 +1,116 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- +REGISTER_OP("Call") + .Input("data: T") + .Output("output: T") + .Attr("T: type") + .Attr("frame_name: string") + .Attr("call_id: int") + .Attr("arg_id: int") + .Attr("is_constant: bool = false") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + + // Handle resource shape / dtype, if present. + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + c->set_output_handle_shapes_and_types(0, *handle_data); + } else { + // Otherwise, propagate shape if output is a constant. + bool is_constant; + TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant)); + if (is_constant) { + c->set_output(0, c->input(0)); + } + } + return OkStatus(); + }) + .Doc(R"Doc( +Creates (or finds) a child frame, and makes `data` available to the child frame. + +This op is used together with `Return` to create recursive calls in the graph. +The unique `frame_name` is used by the `Executor` to identify frames. + +data: The tensor to be made available to the child frame. +frame_name: The name of the child frame. +output: The same tensor as `data`. + +Returns tensors with the same shapes and contents as the input +tensors. + )Doc"); + +REGISTER_OP("RefCall") + .Input("data: Ref(T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Attr("frame_name: string") + .Attr("call_id: int") + .Attr("arg_id: int") + .Attr("is_constant: bool = false") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"Doc( +Creates (or finds) a child frame, and makes `data` available to the child frame. 
+ +This op is used together with `Return` to create recursive calls in the graph. +The unique `frame_name` is used by the `Executor` to identify frames. + +data: The tensor to be made available to the child frame. +frame_name: The name of the child frame. +output: The same tensor as `data`. + +Returns tensors with the same shapes and contents as the input +tensors. + )Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Return") +.Input("data: T") +.Output("output: T") +.Attr("T: type") +.Attr("frame_name: string") +.Attr("call_id: int") +.Attr("arg_id: int") +.SetShapeFn(shape_inference::UnchangedShape) +.Doc(R"Doc( +Exits the current frame to its parent frame. +Exit makes its input `data` available to the parent frame. +data: The list of tensors to be made available to the parent frame. +output: The same list of tensors as `data`. + )Doc"); + +REGISTER_OP("RefReturn") +.Input("data: Ref(T)") +.Output("output: Ref(T)") +.Attr("T: type") +.Attr("frame_name: string") +.Attr("call_id: int") +.Attr("arg_id: int") +.SetShapeFn(shape_inference::UnchangedShape) +.Doc(R"Doc( +Exits the current frame to its parent frame. +Exit makes its input `data` available to the parent frame. +data: The tensors to be made available to the parent frame. +output: The same tensors as `data`. + )Doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index f98d1928d9e156..aedf6a8bb0fa69 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -92,6 +92,8 @@ message RewriterConfig { Toggle function_optimization = 10; // Strips debug-related nodes from the graph (off by default). Toggle debug_stripper = 11; + // Function transformation (default is ON). 
+ Toggle function_transformation = 33; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; // Try to allocate some independent Op outputs contiguously in order to diff --git a/tensorflow/python/client/_pywrap_tf_session.pyi b/tensorflow/python/client/_pywrap_tf_session.pyi index 14645b34c5f5be..ba00dd680250da 100644 --- a/tensorflow/python/client/_pywrap_tf_session.pyi +++ b/tensorflow/python/client/_pywrap_tf_session.pyi @@ -370,6 +370,7 @@ def TF_GetOpList(arg0: TF_Library) -> object: ... def TF_GetRegisteredKernelsForOp(arg0: str) -> TF_Buffer: ... def TF_GetXlaAutoJitEnabled() -> int: ... def TF_GetXlaConstantFoldingDisabled() -> int: ... +def TF_GraphAddFunctionDef(arg0: PyGraph, arg1: bytes) -> None: ... def TF_GraphCopyFunction(arg0: PyGraph, arg1: TF_Function, arg2: TF_Function) -> None: ... def TF_GraphImportGraphDefWithResults(arg0: PyGraph, arg1: TF_Buffer, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ... def TF_GraphImportGraphDefWithResultsNoSerialization(arg0: PyGraph, arg1, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ... diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index b2d3492f99dfd5..a2ef54ec9d6ec6 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -1867,6 +1867,21 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); }); + m.def("TF_GraphAddFunctionDef", + [](PyGraph* graph, py::bytes proto) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + tensorflow::Safe_TF_BufferPtr buf = + tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); + + // Release GIL. 
+ py::gil_scoped_release release; + TF_GraphAddFunctionDef(graph->tf_graph(), buf.get()->data, buf.get()->length, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); + }); + + + m.def("TF_GraphCopyFunction", [](PyGraph* graph, const TF_Function* func, const TF_Function* grad) { tensorflow::Safe_TF_StatusPtr status = diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 848a4c8f23599f..68e4fc94c00673 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -21,6 +21,7 @@ import collections import hashlib +from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python.client import pywrap_tf_session as c_api @@ -152,6 +153,7 @@ def __init__(self, *input_types, **kwargs): self._func_name = kwargs.pop("func_name", None) self._grad_func = kwargs.pop("grad_func", None) self._python_grad_func = kwargs.pop("python_grad_func", None) + self._create_grad_func = kwargs.pop("create_grad_func", False) self._out_names = kwargs.pop("out_names", None) self._extra_kwargs = kwargs @@ -196,6 +198,8 @@ def __call__(self, func): self._func_name, self._grad_func, self._python_grad_func, + self._create_grad_func, + is_gradient=False, out_names=self._out_names, **self._extra_kwargs) @@ -245,6 +249,61 @@ def __del__(self): # been unloaded. Will catch other module unloads as well. +class Declare(object): + """Declares a TensorFlow function. + + The object represents a TensorFlow function which will be defined + later during a graph construction. + + For example, + # Declares a function Foo, which takes a tf.int32 named "n" and a + # tf.float32 named "x" as inputs and returns a tf.float32 named "z" + # as its output. + foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)], + [("z", tf.float32)]) + + # Defines a function Bar calls Foo. 
+ @tf.Defun(tf.float32) + def Bar(x): + return foo(6, x) + + # Defines Foo, with output named "z". + @tf.Defun(tf.int32, tf.float32, out_names=["z"]) + def Foo(n, x): + ... # Calculation. + return result + """ + + + def __init__(self, func_name, inputs, outputs): + """Creates a `Declare` object. + + Args: + func_name: The name of the function. + inputs: A list of (name, data type) pairs of function arguments. + outputs: A list of (name, data type) pairs of function return values. + """ + self._sig = op_def_pb2.OpDef() + self._sig.name = func_name + + def _to_argdef_list(args): + names = [n for n, t in args] + if len(names) != len(set(names)): + raise ValueError("Expected names to all be unique: %s" % str(names)) + return [ + op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n) + for n, t in args + ] + + self._sig.input_arg.extend(_to_argdef_list(inputs)) + self._sig.output_arg.extend(_to_argdef_list(outputs)) + + def __call__(self, *inputs, **kwargs): + inputs = [ops.convert_to_tensor(_) for _ in inputs] + return _call(self._sig, *inputs, **kwargs)[0] + + + class _DefinedFunction(object): """_DefinedFunction encapsulates a function definition and its properties. 
@@ -264,6 +323,8 @@ def __init__(self, func_name=None, grad_func=None, python_grad_func=None, + create_grad_func=False, + is_gradient=False, out_names=None, shape_func=None, capture_by_value=False, @@ -305,6 +366,8 @@ def __init__(self, self._func_name = func_name self._grad_func = grad_func self._python_grad_func = python_grad_func + self._create_grad_func = create_grad_func + self._is_gradient = is_gradient self._out_names = out_names self._shape_func = shape_func self._capture_by_value = capture_by_value @@ -331,12 +394,31 @@ def __init__(self, # is disabled the whole _definition is available and this is simply # another reference to _definition.signature self._op_def = None - + assert isinstance(input_types, (list, tuple)) self._arg_types = input_types self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i) for i in range(len(input_types))] + + self._args = list(zip(self._arg_names,self._arg_types)) + if self._create_grad_func: + grad_func_name = self._func_name #+ "Grad" + out_names = self._out_names.copy() + for (argname, argtype) in self._args: + out_names.append("d" + argname) + # Todo: check if we need to copy all the args so that they don't get passed by reference + self._grad_func = _DefinedFunction(func=func, + argnames=argnames, + input_types=input_types, + func_name=grad_func_name, + grad_func=None, + python_grad_func=None, + create_grad_func=False, + is_gradient=True, + out_names=out_names, + **kwargs) + @property def name(self): """Function name.""" @@ -415,7 +497,8 @@ def _create_definition_if_needed(self): def _create_definition_if_needed_impl(self): """This is not what you want, see _create_definition_if_needed.""" - if self._definition is not None or self._c_func is not None: + if self._definition is not None or self._c_func is not None \ + or (self._is_gradient and not ops.get_default_graph()._is_function(self._func_name)): return # Copy variable collections (by reference) from the parent graph such that @@ -433,18 +516,25 @@ def 
_create_definition_if_needed_impl(self): self._func, self._arg_names, self._arg_types, + self._out_names, self._func_name, + self._is_gradient, self._capture_by_value, self._caller_device, collections_ref=collections_ref, allowlisted_stateful_ops=self._allowlisted_stateful_ops, - capture_resource_var_by_value=self._capture_resource_var_by_value) + capture_resource_var_by_value=self._capture_resource_var_by_value, + functions=parent_graph._functions + ) self._extra_inputs = temp_graph.extra_inputs # pylint: disable=protected-access self._sub_functions = temp_graph._functions # pylint: enable=protected-access + if self._is_gradient and self._func_name: + self._func_name += "Grad" + # Extra kwargs are treated as attrs on the function def. if self._func_name: base_func_name = self._func_name @@ -945,7 +1035,9 @@ def _add_op_and_parents(self, op: ops.Operation): def func_graph_from_py_func(func, arg_names, arg_types, + out_names, name=None, + is_gradient=False, capture_by_value=False, device=None, colocation_stack=None, @@ -953,7 +1045,8 @@ def func_graph_from_py_func(func, collections_ref=None, arg_shapes=None, allowlisted_stateful_ops=None, - capture_resource_var_by_value=True): + capture_resource_var_by_value=True, + functions=None): """Returns a _FuncGraph generated from `func`. Args: @@ -1006,7 +1099,27 @@ def func_graph_from_py_func(func, func_graph.inputs.append(argholder) # Call func and gather the output tensors. 
with vs.variable_scope("", custom_getter=func_graph.getvar): - outputs = func(*func_graph.inputs) + gradient_out_types = [] + if is_gradient: + name = name + "Grad" + outputs = func(*func_graph.inputs) + if not isinstance(outputs,list): + outputs = [outputs] + dinputs = [] + for (out, name) in list(zip(outputs, out_names)): + argholder = array_ops.placeholder(out.op.node_def.attr["T"].type, name="d"+name) + dinputs.append(argholder) + gradient_out_types.append(out.op.node_def.attr["T"].type) + for argtype in arg_types: + gradient_out_types.append(argtype) + from tensorflow.python.ops import gradients_impl + doutputs = gradients_impl.gradients(outputs, func_graph.inputs, dinputs, functions = functions) + if not isinstance(doutputs, list): + doutputs = [doutputs] + outputs.extend(doutputs) + func_graph.inputs.extend(dinputs) + else: + outputs = func(*func_graph.inputs) # There is no way of distinguishing between a function not returning # anything and a function returning None in Python. @@ -1022,10 +1135,24 @@ def func_graph_from_py_func(func, # If func only returned one value, make it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) - if any(_ is None for _ in outputs): - raise ValueError(f"Function {name} can not return None.") + # if any(_ is None for _ in outputs): + # raise ValueError(f"Function {name} can not return None.") # Ensures each output is a Tensor in the function graph. 
- outputs = [ops.convert_to_tensor(t) for t in outputs] + if is_gradient: + tmp_out = [] + for out, out_type in zip(outputs, gradient_out_types): + if out is not None: + tmp_out.append(ops.convert_to_tensor(out)) + else: + if out_type.is_bool: + tmp_out.append(ops.convert_to_tensor(False)) + elif out_type.is_floating: + tmp_out.append(ops.convert_to_tensor(0.0)) + else: + tmp_out.append(ops.convert_to_tensor(0)) + outputs = tmp_out + else: + outputs = [ops.convert_to_tensor(t) for t in outputs] outputs = [func_graph.capture(t) if t.graph is not func_graph else t for t in outputs] func_graph.outputs = outputs diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 34b1eed754bbed..f300fd1caad572 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2506,6 +2506,18 @@ def _add_function_recursive(self, function, overwrite=False) -> None: else: self._add_function(f) + def _declare_function_from_op_def(self, op_def) -> None: + + function_def = function_pb2.FunctionDef() + function_def.signature.CopyFrom(op_def) + + with self._c_graph.get() as c_graph: + try: + pywrap_tf_session.TF_GraphAddFunctionDef(c_graph,function_def.SerializeToString()) + except errors.InvalidArgumentError: + pass + + def _add_function(self, function) -> None: """Adds a function to the graph. @@ -2532,12 +2544,12 @@ def _add_function(self, function) -> None: # pylint: disable=protected-access with self._c_graph.get() as c_graph: with function._c_func.get() as func: - if getattr(function, "_grad_func", None): - # For deprecated _DefinedFunction. - with function._grad_func._c_func.get() as gradient: - pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient) - else: - pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None) + # if getattr(function, "_grad_func", None): + # # For deprecated _DefinedFunction. 
+ # with function._grad_func._c_func.get() as gradient: + # pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient) + # else: + pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None) # pylint: enable=protected-access self._functions[compat.as_str(name)] = function @@ -2676,6 +2688,11 @@ def _create_op_internal( input_ops = set(t.op for t in inputs) control_inputs = self._control_dependencies_for_inputs(input_ops) + + if op_def: + self._declare_function_from_op_def(op_def) + + # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a # Session.run call cannot occur between creating and mutating the op. with self._mutation_lock(): diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 14e316a6433f6e..187c71f8a7436b 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -61,7 +61,8 @@ def gradients(ys, gate_gradients=False, aggregation_method=None, stop_gradients=None, - unconnected_gradients=UnconnectedGradients.NONE): + unconnected_gradients=UnconnectedGradients.NONE, + functions = None): """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`. `ys` and `xs` are each a `Tensor` or a list of tensors. 
`grad_ys` @@ -181,7 +182,7 @@ def gradients(ys, return gradients_util._GradientsHelper( ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, - unconnected_gradients) + unconnected_gradients, functions) # pylint: enable=protected-access diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index fa568ea706cf36..e4647c5adc1776 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -512,6 +512,7 @@ def _GradientsHelper(ys, aggregation_method=None, stop_gradients=None, unconnected_gradients=UnconnectedGradients.NONE, + functions = None, src_graph=None): """Implementation of gradients().""" if context.executing_eagerly(): @@ -536,7 +537,7 @@ def _GradientsHelper(ys, flat_grads = _GradientsHelper(flat_ys, flat_xs, flat_grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, - unconnected_gradients, src_graph) + unconnected_gradients, src_graph = src_graph) return composite_tensor_gradient.replace_flat_tensors_for_gradients( xs, flat_grads) @@ -637,6 +638,9 @@ def _GradientsHelper(ys, is_partitioned_call = _IsPartitionedCall(op) # pylint: disable=protected-access is_func_call = src_graph._is_function(op.type) or is_partitioned_call + if not is_func_call and functions is not None: + is_func_call = op.type in functions + # pylint: enable=protected-access has_out_grads = any( isinstance(g, tensor_lib.Tensor) or g for g in out_grads @@ -665,6 +669,8 @@ def _GradientsHelper(ys, break else: func_call = src_graph._get_function(op.type) # pylint: disable=protected-access + if func_call is None and functions is not None: + func_call = functions.get(op.type,None) # Note that __defun is not set if the graph is # imported. If it's set, we prefer to access the original # defun.