diff --git a/.bazelrc b/.bazelrc
index 0322618b53f..ff910cd186e 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -149,6 +149,12 @@ build --experimental_cc_shared_library
# cc_shared_library ensures no library is linked statically more than once.
build --experimental_link_static_libraries_once=false
+# Prevent regressions on those two incompatible changes
+# TODO: remove those flags when they are flipped in the default Bazel version TF uses.
+build --incompatible_enforce_config_setting_visibility
+# TODO: also enable this flag after fixing the visibility violations
+# build --incompatible_config_setting_private_default_visibility
+
# Default options should come above this line.
# Allow builds using libc++ as a linker library
@@ -324,7 +330,9 @@ build:linux --copt="-Wunused-result"
# build:linux --copt="-Werror=unused-result"
# Add switch as an error on Linux.
build:linux --copt="-Wswitch"
-# build:linux --copt="-Werror=switch"
+build:linux --copt="-Werror=switch"
+# Required for building with clang
+build:linux --copt="-Wno-error=unused-but-set-variable"
# On Windows, `__cplusplus` is wrongly defined without this switch
# See https://devblogs.microsoft.com/cppblog/msvc-now-correctly-reports-__cplusplus/
@@ -382,8 +390,8 @@ build:windows --host_copt=-DNOGDI
# MSVC (Windows): Standards-conformant preprocessor mode
# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview
-build:windows --copt=/experimental:preprocessor
-build:windows --host_copt=/experimental:preprocessor
+build:windows --copt=/Zc:preprocessor
+build:windows --host_copt=/Zc:preprocessor
# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
@@ -559,8 +567,8 @@ build:rbe_linux_py3_base --python_path="/usr/local/bin/python3.9"
build:rbe_linux_py3_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"
build:rbe_win --config=rbe
-build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_06152022:toolchain"
-build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_06152022:cc-toolchain-x64_windows"
+build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_01232023:toolchain"
+build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_01232023:cc-toolchain-x64_windows"
build:rbe_win --extra_execution_platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019"
build:rbe_win --host_platform="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019"
build:rbe_win --platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019"
@@ -672,6 +680,7 @@ build:asan --copt -g
build:asan --copt -O3
build:asan --copt -fno-omit-frame-pointer
build:asan --linkopt -fsanitize=address
+build:asan --@libjpeg_turbo//:noasm=yes
# Memory sanitizer
# CC=clang bazel build --config msan
@@ -695,7 +704,17 @@ build:ubsan --linkopt -fsanitize=undefined
build:ubsan --linkopt -lubsan
# Disable TFRT integration for now unless --config=tfrt is specified.
-build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/common,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils
+build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils
# TODO(b/240450920): We are in the process of migrating JitRt backend to XLA
# and while we are doing this we can't keep it buildable/testable in OSS.
-build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/common,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils
+build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils
+
+# TF Fuzztest config
+try-import fuzztest.bazelrc
+run:tf_fuzztest --config=fuzztest
+# Should aim to remove these
+build:tf_fuzztest --action_env=CC=clang
+build:tf_fuzztest --action_env=CXX=clang++
+build:tf_fuzztest --spawn_strategy=sandboxed
+build:tf_fuzztest --config=monolithic
+build:tf_fuzztest --@libjpeg_turbo//:noasm=yes
diff --git a/.bazelversion b/.bazelversion
index e230c8396d1..f53152b50eb 100644
--- a/.bazelversion
+++ b/.bazelversion
@@ -1 +1,2 @@
-5.3.0
\ No newline at end of file
+5.3.0
+# NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml b/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml
index 6e4753d8674..70bdc6160cb 100644
--- a/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml
+++ b/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml
@@ -23,6 +23,17 @@ body:
value: |
Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.
+ - type: dropdown
+ id: tf-nightly
+ attributes:
+ label: Have you reproduced the bug with TF nightly?
+ description: It is strongly suggested that you have reproduced the bug with [TF nightly](https://www.tensorflow.org/install/pip#nightly)
+ options:
+ - "Yes"
+ - "No"
+ validations:
+ required: true
+
- type: markdown
attributes:
value: |
@@ -38,6 +49,7 @@ body:
- binary
validations:
required: true
+
- type: input
id: tfversion
attributes:
diff --git a/.github/ISSUE_TEMPLATE/tflite-other.md b/.github/ISSUE_TEMPLATE/tflite-other.md
new file mode 100644
index 00000000000..8b8246f2b72
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/tflite-other.md
@@ -0,0 +1,62 @@
+name: TensorFlow Lite Other Issue description: Use this template to report any
+issue in TensorFlow Lite that is not about Converters, Play Services or Ops
+body: - type: dropdown id: issue-type attributes: label: Issue Type description:
+What type of issue would you like to report? multiple: false options: - Bug -
+Build/Install - Performance - Support - Feature Request - Documentation Feature
+Request - Documentation Bug - Others validations: required: true - type:
+markdown attributes: value: | Please make sure that this is a bug. As per our
+[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we
+only address code/doc bugs, performance issues, feature requests and
+build/installation issues on GitHub.
+
+- type: markdown
+ attributes:
+ value: |
+ You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+
+- type: dropdown id: source attributes: label: Source description: Tensorflow
+ installed from options: - source - binary validations: required: true
+
+- type: input id: tfversion attributes: label: Tensorflow Version description:
+  placeholder: e.g., tf 2.8 validations: required: true
+
+- type: dropdown id: Code attributes: label: Custom Code description:
+ options: - "Yes" - "No" validations: required: true
+
+- type: input id: OS attributes: label: OS Platform and Distribution
+ description: placeholder: e.g., Linux Ubuntu 16.04 validations: required:
+ false
+
+- type: input id: Mobile attributes: label: Mobile device description:
+ placeholder: e.g., Linux Ubuntu 16.04 validations: required: false
+
+- type: input id: Python attributes: label: Python version description:
+ placeholder: e.g., 3.9 validations: required: false
+
+- type: input id: Bazel attributes: label: Bazel version description: if
+ compiling from source placeholder: validations: required: false
+
+- type: input id: Compiler attributes: label: GCC/Compiler version
+ description: if compiling from source placeholder: validations: required:
+ false
+
+- type: input id: Cuda attributes: label: CUDA/cuDNN version description:
+ placeholder: validations: required: false
+
+- type: input id: Gpu attributes: label: GPU model and memory description: if
+ compiling from source placeholder: validations: required: false
+
+- type: textarea id: what-happened attributes: label: Current Behaviour?
+ description: Also tell us, what did you expect to happen? placeholder: Tell
+ us what you see! value: "A bug happened!" render: shell validations:
+ required: true
+
+- type: textarea id: code-to-reproduce attributes: label: Standalone code to
+ reproduce the issue description: Provide a reproducible test case that is
+ the bare minimum necessary to generate the problem. If possible, please
+ share a link to Colab/Jupyter/any notebook. placeholder: Tell us what you
+ see! value: render: shell validations: required: true
+
+- type: textarea id: logs attributes: label: Relevant log output description:
+ Please copy and paste any relevant log output. This will be automatically
+ formatted into code, so no need for backticks. render: shell
diff --git a/.github/bot_config.yml b/.github/bot_config.yml
index 3f039b9e176..bab88af1a8e 100644
--- a/.github/bot_config.yml
+++ b/.github/bot_config.yml
@@ -15,9 +15,10 @@
# A list of assignees
assignees:
- - tilakrayal
+ - synandi
- tiruk007
- - Mohantym
+ - gaikwadrahul8
+ - pjpratik
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml
index 1698cf0f0b3..b601b0054c7 100644
--- a/.github/workflows/arm-cd.yml
+++ b/.github/workflows/arm-cd.yml
@@ -26,9 +26,14 @@ jobs:
build:
if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
runs-on: [self-hosted, linux, ARM64]
+ continue-on-error: ${{ matrix.experimental }}
strategy:
matrix:
- pyver: ['3.7', '3.8', '3.9', '3.10']
+ pyver: ['3.8', '3.9', '3.10']
+ experimental: [false]
+ include:
+ - pyver: '3.11'
+ experimental: true
steps:
- name: Stop old running containers (if any)
shell: bash
@@ -46,12 +51,12 @@ jobs:
run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true
- name: Checkout repository for nightly (skipped for releases)
if: ${{ github.event_name == 'schedule' }}
- uses: actions/checkout@v3
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
with:
ref: 'nightly'
- name: Checkout repository for releases (skipped for nightly)
if: ${{ github.event_name == 'push' }}
- uses: actions/checkout@v3
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Build and test pip wheel
shell: bash
run: |
diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml
index 0fcf49e340a..1592f4ed18a 100644
--- a/.github/workflows/arm-ci-extended.yml
+++ b/.github/workflows/arm-ci-extended.yml
@@ -50,7 +50,7 @@ jobs:
shell: bash
run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Build binary and run non-pip tests
shell: bash
run: |
diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml
index 067e29131e7..e6ddbb9eec9 100644
--- a/.github/workflows/arm-ci.yml
+++ b/.github/workflows/arm-ci.yml
@@ -21,14 +21,15 @@ on:
- master
- r2.**
pull_request:
- types: [opened, synchronize, reopened]
+ types: [labeled, opened, synchronize, reopened]
branches:
- master
- r2.**
jobs:
build:
- if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
+ # Don't do this in forks, and if labeled, only for 'kokoro:force-run'
+ if: github.repository == 'tensorflow/tensorflow' && (github.event.action != 'labeled' || (github.event.action == 'labeled' && github.event.label.name == 'kokoro:force-run'))
runs-on: [self-hosted, linux, ARM64]
strategy:
matrix:
@@ -49,14 +50,14 @@ jobs:
shell: bash
run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Build and test pip wheel
shell: bash
run: |
CI_DOCKER_BUILD_EXTRA_PARAMS='--build-arg py_major_minor_version=${{ matrix.pyver }}' \
./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh
- name: Upload pip wheel to GitHub
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1
with:
name: tensorflow_py${{ matrix.pyver }}_wheel
path: /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/whl/*.whl
diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml
index fdae2ac19e6..21ac759f3ef 100644
--- a/.github/workflows/cffconvert.yml
+++ b/.github/workflows/cffconvert.yml
@@ -27,9 +27,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out a copy of the repository
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Check whether the citation metadata from CITATION.cff is valid
- uses: citation-file-format/cffconvert-github-action@2.0.0
+ uses: citation-file-format/cffconvert-github-action@4cf11baa70a673bfdf9dad0acc7ee33b3f4b6084 # v2.0.0
with:
args: "--validate"
diff --git a/.github/workflows/issue-on-pr-rollback.yml b/.github/workflows/issue-on-pr-rollback.yml
index ce0182bedc2..fa76923a2ba 100644
--- a/.github/workflows/issue-on-pr-rollback.yml
+++ b/.github/workflows/issue-on-pr-rollback.yml
@@ -27,9 +27,9 @@ jobs:
startsWith(github.event.head_commit.message, 'Rollback of PR #')
steps:
- name: Checkout repo
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Create a new Github Issue
- uses: actions/github-script@v5
+ uses: actions/github-script@d556feaca394842dc55e4734bf3bb9f685482fa0 # v6.3.3
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml
index f1b539f551b..e97f34472d8 100644
--- a/.github/workflows/pylint-presubmit.yml
+++ b/.github/workflows/pylint-presubmit.yml
@@ -25,17 +25,17 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Get file changes
id: get_file_changes
- uses: trilom/file-changes-action@v1.2.4
+ uses: trilom/file-changes-action@a6ca26c14274c33b15e6499323aac178af06ad4b # v1.2.4
with:
output: ' '
- name: Report list of changed files
run: |
echo Changed files: ${{ steps.get_file_changes.outputs.files }}
- name: Set up Python 3.9
- uses: actions/setup-python@v2
+ uses: actions/setup-python@2c3dd9e7e29afd70cc0950079bde6c979d1f69f9 # v4.3.1
with:
python-version: "3.9"
- name: Install Python dependencies
diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml
index a57852a9644..5ff69e46805 100644
--- a/.github/workflows/release-branch-cherrypick.yml
+++ b/.github/workflows/release-branch-cherrypick.yml
@@ -42,7 +42,7 @@ jobs:
if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
steps:
- name: Checkout code
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
with:
ref: ${{ github.event.inputs.release_branch }}
- name: Get some helpful info for formatting
@@ -52,10 +52,10 @@ jobs:
git config --global user.email "jenkins@tensorflow.org"
git fetch origin master
git cherry-pick ${{ github.event.inputs.git_commit }}
- echo ::set-output name=SHORTSHA::$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")
- echo ::set-output name=TITLE::$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")
+ echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")" >> "$GITHUB_OUTPUT"
+ echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT"
- name: Create Pull Request with changes
- uses: peter-evans/create-pull-request@v3
+ uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3
with:
title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ steps.cherrypick.outputs.TITLE }}"'
committer: TensorFlow Release Automation
diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml
index 8f9dab872b6..1c520aa86fd 100644
--- a/.github/workflows/scorecards-analysis.yml
+++ b/.github/workflows/scorecards-analysis.yml
@@ -34,23 +34,18 @@ jobs:
# Needed to upload the results to code-scanning dashboard.
security-events: write
id-token: write
- actions: read
- contents: read
steps:
- name: "Checkout code"
- uses: actions/checkout@ec3a7ce113134d7a93b817d10a8272cb61118579 # v2.4.0
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
with:
persist-credentials: false
- name: "Run analysis"
- uses: ossf/scorecard-action@08dd0cebb088ac0fd6364339b1b3b68b75041ea8 # v2.0.0-alpha.2
+ uses: ossf/scorecard-action@15c10fcf1cf912bd22260bfec67569a359ab87da # v2.1.1
with:
results_file: results.sarif
results_format: sarif
- # Read-only PAT token. To create it,
- # follow the steps in https://github.com/ossf/scorecard-action#pat-token-creation.
- repo_token: ${{ secrets.SCORECARD_READ_TOKEN }}
# Publish the results to enable scorecard badges. For more details, see
# https://github.com/ossf/scorecard-action#publishing-results.
# For private repositories, `publish_results` will automatically be set to `false`,
@@ -59,7 +54,7 @@ jobs:
# Upload the results as artifacts (optional).
- name: "Upload artifact"
- uses: actions/upload-artifact@82c141cc518b40d92cc801eee768e7aafc9c2fa2 # v2.3.1
+ uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1
with:
name: SARIF file
path: results.sarif
@@ -67,6 +62,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
- uses: github/codeql-action/upload-sarif@5f532563584d71fdef14ee64d17bafb34f751ce5 # v1.0.26
+ uses: github/codeql-action/upload-sarif@896079047b4bb059ba6f150a5d87d47dde99e6e5 # v2.11.6
with:
sarif_file: results.sarif
diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml
index 41b0fe5a13a..c898381efd5 100644
--- a/.github/workflows/sigbuild-docker-branch.yml
+++ b/.github/workflows/sigbuild-docker-branch.yml
@@ -31,23 +31,23 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [python3.7, python3.8, python3.9, python3.10]
+ python-version: [python3.8, python3.9, python3.10, python3.11]
steps:
-
name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
-
name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v1
+ uses: docker/setup-buildx-action@8c0edbc76e98fa90f69d9a2c020dcb50019dc325 # v2.2.1
-
name: Login to DockerHub
- uses: docker/login-action@v1
+ uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Login to GCR
- uses: docker/login-action@v1
+ uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0
with:
registry: gcr.io
username: _json_key
@@ -55,14 +55,14 @@ jobs:
-
name: Generate variables for cache busting and tag naming
run: |
- echo "::set-output name=DATE::$(date +'%Y-%m-%d')"
+ echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
# Converts r2.9 to just 2.9
- echo "::set-output name=REF::$(echo $GITHUB_REF_NAME | sed 's/r//g')"
+ echo "REF=$(echo $GITHUB_REF_NAME | sed 's/r//g')" >> "$GITHUB_OUTPUT"
id: vars
-
name: Build and push
id: docker_build
- uses: docker/build-push-action@v2
+ uses: docker/build-push-action@c56af957549030174b10d6867f20e78cfd7debc5 # v3.2.0
with:
push: true
context: ./tensorflow/tools/tf_sig_build_dockerfiles
diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml
index c77c0d66311..065fd91319e 100644
--- a/.github/workflows/sigbuild-docker-presubmit.yml
+++ b/.github/workflows/sigbuild-docker-presubmit.yml
@@ -29,18 +29,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [python3.7, python3.8, python3.9, python3.10]
+ python-version: [python3.8, python3.9, python3.10, python3.11]
steps:
-
name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
-
name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v1
+ uses: docker/setup-buildx-action@8c0edbc76e98fa90f69d9a2c020dcb50019dc325 # v2.2.1
-
name: Login to GCR
if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging')
- uses: docker/login-action@v1
+ uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0
with:
registry: gcr.io
username: _json_key
@@ -48,12 +48,12 @@ jobs:
-
name: Grab the date to do cache busting (assumes same day OK to keep)
run: |
- echo "::set-output name=DATE::$(date +'%Y-%m-%d')"
+ echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
id: date
-
name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied
id: docker_build
- uses: docker/build-push-action@v2
+ uses: docker/build-push-action@c56af957549030174b10d6867f20e78cfd7debc5 # v3.2.0
with:
push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') }}
context: ./tensorflow/tools/tf_sig_build_dockerfiles
@@ -69,17 +69,17 @@ jobs:
cache-to: type=inline
-
name: Add a comment with the pushed containers
- uses: mshick/add-pr-comment@v1
+ uses: mshick/add-pr-comment@a65df5f64fc741e91c59b8359a4bc56e57aaf5b1 # v2
if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging')
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
message: |
I pushed these containers:
+ - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.11`
- `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.10`
- `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.9`
- `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.8`
- - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.7`
Re-apply the `build and push to gcr.io for staging` label to rebuild and push again. This comment will only be posted once.
-
diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml
index 276a0abc242..c9b12a39076 100644
--- a/.github/workflows/sigbuild-docker.yml
+++ b/.github/workflows/sigbuild-docker.yml
@@ -34,23 +34,23 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [python3.7, python3.8, python3.9, python3.10]
+ python-version: [python3.8, python3.9, python3.10, python3.11]
steps:
-
name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
-
name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v1
+ uses: docker/setup-buildx-action@8c0edbc76e98fa90f69d9a2c020dcb50019dc325 # v2.2.1
-
name: Login to DockerHub
- uses: docker/login-action@v1
+ uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Login to GCR
- uses: docker/login-action@v1
+ uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0
with:
registry: gcr.io
username: _json_key
@@ -61,15 +61,15 @@ jobs:
# [[:digit:]] searches for numbers and \+ joins them together
major_version=$(grep "^#define TF_MAJOR_VERSION" ./tensorflow/core/public/version.h | grep -o "[[:digit:]]\+")
minor_version=$(grep "^#define TF_MINOR_VERSION" ./tensorflow/core/public/version.h | grep -o "[[:digit:]]\+")
- echo ::set-output name=TF_VERSION::${major_version}.${minor_version}
+ echo "TF_VERSION=${major_version}.${minor_version}" >> "$GITHUB_OUTPUT"
# Also get the current date to do cache busting. Assumes one day
# is an ok range for rebuilds
- echo "::set-output name=DATE::$(date +'%Y-%m-%d')"
+ echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
id: tf-version
-
name: Build and push
id: docker_build
- uses: docker/build-push-action@v2
+ uses: docker/build-push-action@c56af957549030174b10d6867f20e78cfd7debc5 # v3.2.0
with:
push: true
context: ./tensorflow/tools/tf_sig_build_dockerfiles
diff --git a/.github/workflows/trusted-partners.yml b/.github/workflows/trusted-partners.yml
index abf62dd2b8a..7c2fb863d15 100644
--- a/.github/workflows/trusted-partners.yml
+++ b/.github/workflows/trusted-partners.yml
@@ -30,9 +30,9 @@ jobs:
github.event.sender.type == 'User'
steps:
- name: Checkout repo
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Trusted-Partners-PR
- uses: actions/github-script@v6
+ uses: actions/github-script@d556feaca394842dc55e4734bf3bb9f685482fa0 # v6.3.3
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -49,6 +49,9 @@ jobs:
case "nvidia.com":
console.log(await script.filter({github, context, domain}));
break;
+ case "linaro.org":
+ console.log(await script.filter({github, context, domain}));
+ break;
case "google.com":
console.log("Googler. No action necessary");
break;
diff --git a/.github/workflows/trusted_partners.js b/.github/workflows/trusted_partners.js
index 6b6de25946e..60de918108d 100644
--- a/.github/workflows/trusted_partners.js
+++ b/.github/workflows/trusted_partners.js
@@ -39,9 +39,9 @@ const get_email_domain = async ({github, username}) => {
return domain;
};
-/** For trusted parters like Intel, we want to auto-run tests and mark the PR as ready to pull
+/** For trusted partners like Intel, we want to auto-run tests
This allows us to reduce the delay to external partners
- Add Labels - kokoro:force-run, ready to pull
+ Add Labels - kokoro:force-run
The PR is also assigned to specific teams to fast track review
Additional reviewers can be added manually based on PR contents
@param {!object}
@@ -50,34 +50,41 @@ const get_email_domain = async ({github, username}) => {
@return {string} Returns the message with labels attached and assignees added
*/
const filter_action = async ({github, context, domain}) => {
- const labels = ['kokoro:force-run', 'ready to pull'];
+ const labels = ['kokoro:force-run'];
let assignees = [];
const title = context.payload.pull_request && context.payload.pull_request.title;
+ const lowercased_title = (title || '').toLowerCase();
const onednn_assignees = ['penpornk'];
- if (title && title.toLowerCase().includes("onednn"))
- assignees = onednn_assignees;
+ if (lowercased_title.includes('onednn')) assignees = onednn_assignees;
const intel_windows_assignees = ['nitins17', 'learning-to-play'];
- if (title && title.toLowerCase().includes('intel') &&
- title.toLowerCase().includes('windows') && domain.includes('intel.com'))
+ if (lowercased_title.includes('intel') &&
+ lowercased_title.includes('windows') && domain.includes('intel.com'))
assignees = intel_windows_assignees;
const apple_silicon_assignees = ['penpornk', 'nitins17'];
- if (title && title.toLowerCase().includes('apple') &&
- title.toLowerCase().includes('silicon') && domain.includes('apple.com'))
+ if (lowercased_title.includes('apple') &&
+ lowercased_title.includes('silicon') && domain.includes('apple.com'))
assignees = apple_silicon_assignees;
- if (title && title.toLowerCase().includes('nvidia') &&
- domain.includes('nvidia.com')) {
- if (title.toLowerCase().includes('jax')) {
+ if (lowercased_title.includes('tf-trt') && domain.includes('nvidia.com')) {
+ assignees.push(
+ 'DEKHTIARJonathan', 'meena-at-work', 'nluehr', 'pjannaty', 'poulsbo');
+ } else if (
+ lowercased_title.includes('nvidia') && domain.includes('nvidia.com')) {
+ if (lowercased_title.includes('jax')) {
assignees.push('hawkinsp', 'yashk2810', 'skye');
}
- if (title.toLowerCase().includes('xla') ||
- title.toLowerCase().includes('gpu')) {
+ if (lowercased_title.includes('xla') || lowercased_title.includes('gpu')) {
assignees.push('cheshire', 'gcforster', 'reedwm', 'chsigg', 'xla-rotation');
}
- if (title.toLowerCase().includes('tf')) {
+ if (lowercased_title.includes('tf')) {
assignees.push('rohan100jain', 'bfontain');
}
}
+ if (lowercased_title.includes('linaro') && domain.includes('linaro.org')) {
+ if (lowercased_title.includes('arm_ci')) {
+ assignees.push('nitins17', 'penpornk');
+ }
+ }
const resp_label = await github.rest.issues.addLabels({
issue_number: context.issue.number,
diff --git a/.github/workflows/update-nightly.yml b/.github/workflows/update-nightly.yml
index 0265ffbebe2..60372fddd27 100644
--- a/.github/workflows/update-nightly.yml
+++ b/.github/workflows/update-nightly.yml
@@ -23,7 +23,7 @@ jobs:
if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
runs-on: ubuntu-latest
steps:
- - uses: zofrex/mirror-branch@v1
+ - uses: zofrex/mirror-branch@a8809f0b42f9dfe9b2c5c2162a46327c23d15266 # v1.0.3
name: Set nightly branch to master HEAD
with:
target-branch: 'nightly'
diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml
index 2f86ff2b2e5..ce31d59868a 100644
--- a/.github/workflows/update-rbe.yml
+++ b/.github/workflows/update-rbe.yml
@@ -27,7 +27,7 @@ jobs:
if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
steps:
- name: Checkout code
- uses: actions/checkout@v2
+ uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0
- name: Update the RBE Configs
run: |
function map() {
@@ -48,28 +48,40 @@ jobs:
# See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/toolchains/remote_config/configs.bzl
# This is a mapping of name_container_map keys under sigbuild_tf_configs
# to tag names on gcr.io/tensorflow-sigs/build.
+ # TF 2.9
map sigbuild-r2.9 2.9-python3.9
- map sigbuild-r2.9-python3.7 2.9-python3.7
map sigbuild-r2.9-python3.8 2.9-python3.8
map sigbuild-r2.9-python3.9 2.9-python3.9
map sigbuild-r2.9-python3.10 2.9-python3.10
+ # TF 2.10
map sigbuild-r2.10 2.10-python3.9
- map sigbuild-r2.10-python3.7 2.10-python3.7
map sigbuild-r2.10-python3.8 2.10-python3.8
map sigbuild-r2.10-python3.9 2.10-python3.9
map sigbuild-r2.10-python3.10 2.10-python3.10
- map sigbuild-128 128-python3.9
- map sigbuild-128-python3.7 128-python3.7
- map sigbuild-128-python3.8 128-python3.8
- map sigbuild-128-python3.9 128-python3.9
- map sigbuild-128-python3.10 128-python3.10
+ # TF 2.11
map sigbuild-r2.11 2.11-python3.9
- map sigbuild-r2.11-python3.7 2.11-python3.7
map sigbuild-r2.11-python3.8 2.11-python3.8
map sigbuild-r2.11-python3.9 2.11-python3.9
- map sigbuild-r2.11-python3.11 2.11-python3.10
+ map sigbuild-r2.11-python3.10 2.11-python3.10
+ # WIP Clang Containers, used by TVCs
+ map sigbuild-57469 57469-python3.9
+ map sigbuild-57469-python3.8 57469-python3.8
+ map sigbuild-57469-python3.9 57469-python3.9
+ map sigbuild-57469-python3.10 57469-python3.10
+ # TF 2.12
+ map sigbuild-r2.12 2.12-python3.9
+ map sigbuild-r2.12-python3.8 2.12-python3.8
+ map sigbuild-r2.12-python3.9 2.12-python3.9
+ map sigbuild-r2.12-python3.10 2.12-python3.10
+ map sigbuild-r2.12-python3.11 2.12-python3.11
+ # TF 2.12 + Clang (containers are the same, but env vars in configs.bzl are different)
+ map sigbuild-r2.12-clang 2.12-python3.9
+ map sigbuild-r2.12-clang-python3.8 2.12-python3.8
+ map sigbuild-r2.12-clang-python3.9 2.12-python3.9
+ map sigbuild-r2.12-clang-python3.10 2.12-python3.10
+ map sigbuild-r2.12-clang-python3.11 2.12-python3.11
- name: Create Pull Request with changes
- uses: peter-evans/create-pull-request@v3
+ uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3
with:
title: Update the RBE images to the latest container versions
committer: TensorFlow Release Automation
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 01e20da7c87..ccc170b5c6e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -243,7 +243,7 @@ There are two ways to run TensorFlow unit tests.
For a single component e.g. softmax op:
```bash
- bazel test ${flags} tensorflow/python/kernel_tests:softmax_op_test
+ bazel test ${flags} tensorflow/python/kernel_tests/nn_ops:softmax_op_test
```
For a single/parameterized test e.g. `test_capture_variables` in
diff --git a/README.md b/README.md
index 73e75c1df81..c94227d26d7 100644
--- a/README.md
+++ b/README.md
@@ -104,6 +104,19 @@ for general questions and discussion, and please direct specific questions to
The TensorFlow project strives to abide by generally accepted best practices in
open-source software development.
+## Patching guidelines
+
+Follow these steps to patch a specific version of TensorFlow, for example, to
+apply fixes to bugs or security vulnerabilities:
+
+* Clone the TensorFlow repo and switch to the corresponding branch for your
+ desired TensorFlow version, for example, branch `r2.8` for version 2.8.
+* Apply (that is, cherry pick) the desired changes and resolve any code
+ conflicts.
+* Run TensorFlow tests and ensure they pass.
+* [Build](https://www.tensorflow.org/install/source) the TensorFlow pip
+ package from source.
+
## Continuous build status
You can find more community-supported platforms and configurations in the
diff --git a/RELEASE.md b/RELEASE.md
index 40320f2a172..ea4ab08237e 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,18 +1,114 @@
-# Release 2.12.0
+# Release 2.13.0
-* `tf.keras`:
+# Breaking Changes
- * Added `jit_compile` as a settable property to `tf.keras.Model`.
- * Added `synchronized` optional parameter to `layers.BatchNormalization`.
- * Added deprecation warning to
- `layers.experimental.SyncBatchNormalization` and suggested to use
- `layers.BatchNormalization` with `synchronized=True` instead.
+*
+*
+
+# Known Caveats
+
+*
+*
+*
+
+# Major Features and Improvements
+
+* `tf.lite`:
+
+ * Add 16-bit and 64-bit float type support for built-in op `cast`.
+
+* `tf.keras`
+
+ * Added Keras metrics `tf.keras.metrics.FBetaScore` and
+ `tf.keras.metrics.F1Score`.
+
+# Bug Fixes and Other Changes
+
+*
+*
+*
+
+# Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+, , , , ,
+
+
+# Release 2.12.0
# Breaking Changes
*
*
+* Build, Compilation and Packaging
+
+ * Removal of redundant packages: the `tensorflow-gpu` and `tf-nightly-gpu`
+ packages have been effectively removed and replaced with packages that
+ direct users to switch to `tensorflow` or `tf-nightly` respectively.
+ The naming difference was the only difference between the two sets of
+ packages ever since TensorFlow 2.1, so there is no loss of functionality
+ or GPU support. See
+ https://pypi.org/project/tensorflow-gpu for more details.
+
+* `tf.function`:
+
+ * tf.function now uses the Python inspect library directly for parsing
+ the signature of the Python function it is decorated on.
+ * This can break certain cases that were previously ignored where the
+ signature is malformed, e.g.
+ * Using functools.wraps on a function with different signature
+ * Using functools.partial with an invalid tf.function input
+ * tf.function now enforces input parameter names to be valid Python
+ identifiers. Incompatible names are automatically sanitized similarly to
+ existing SavedModel signature behavior.
+ * Parameterless tf.functions are assumed to have an empty input_signature
+ instead of an undefined one even if the input_signature is unspecified.
+ * tf.types.experimental.TraceType now requires an additional
+ `placeholder_value` method to be defined.
+ * tf.function now traces with placeholder values generated by TraceType
+ instead of the value itself.
+
+* `tf.config.experimental.enable_mlir_graph_optimization`:
+
+ * Experimental API removed.
+
+* `tf.config.experimental.disable_mlir_graph_optimization`:
+
+ * Experimental API removed.
+
+* `tf.keras`
+
+ * Moved all saving-related utilities to a new namespace, `keras.saving`,
+ i.e. `keras.saving.load_model`, `keras.saving.save_model`,
+ `keras.saving.custom_object_scope`, `keras.saving.get_custom_objects`,
+ `keras.saving.register_keras_serializable`,
+ `keras.saving.get_registered_name` and
+ `keras.saving.get_registered_object`.
+ The previous API locations (in `keras.utils` and `keras.models`) will
+ stay available indefinitely, but we recommend that you update your code
+ to point to the new API locations.
+ * Improvements and fixes in Keras loss masking:
+ * Whether you represent a ragged tensor as a `tf.RaggedTensor` or using
+ [keras masking](https://www.tensorflow.org/guide/keras/masking_and_padding),
+ the returned loss values should be the identical to each other.
+ In previous versions Keras may have silently ignored the mask.
+ * If you use masked losses with Keras the loss values may be different
+ in TensorFlow `2.12` compared to previous versions.
+ * In cases where the mask was previously ignored, you will now get
+ an error if you pass a mask with an incompatible shape.
+
+* `tf.SavedModel`
+
+ * Introduce new class `tf.saved_model.experimental.Fingerprint` that
+ contains the fingerprint of the SavedModel. See the
+ [SavedModel Fingerprinting RFC](https://github.com/tensorflow/community/pull/415)
+ for details.
+ * Introduce API `tf.saved_model.experimental.read_fingerprint(export_dir)`
+ for reading the fingerprint of a SavedModel.
+
+
# Known Caveats
*
@@ -25,13 +121,90 @@
* Add 16-bit float type support for built-in op `fill`.
* Transpose now supports 6D tensors.
+ * Float LSTM now supports diagonal recurrent tensors:
+ https://arxiv.org/abs/1903.08023
* `tf.keras`:
+ * The new Keras model saving format (`.keras`) is available. You can start
+ using it via `model.save(f"{fname}.keras", save_format="keras_v3")`. In
+ the future it will become the default for all files with the `.keras`
+ extension. This file format targets the Python runtime only and makes
+ it possible to reload Python objects identical to the saved originals.
+ The format supports non-numerical state such as vocabulary files and
+ lookup tables, and it is easy to customize in the case of custom layers
+ with exotic elements of state (e.g. a FIFOQueue). The format
+ does not rely on bytecode or pickling, and is safe by default. Note
+ that as a result, Python `lambdas` are disallowed at loading time. If
+ you want to use `lambdas`, you can pass `safe_mode=False` to the loading
+ method (only do this if you trust the source of the model).
+ * Added a `model.export(filepath)` API to create a lightweight SavedModel
+ artifact that can be used for inference (e.g. with TF-Serving).
+ * Added `keras.export.ExportArchive` class for low-level customization of
+ the process of exporting SavedModel artifacts for inference.
+ Both ways of exporting models are based on `tf.function` tracing
+ and produce a TF program composed of TF ops. They are meant primarily
+ for environments where the TF runtime is available,
+ but not the Python interpreter, as is typical
+ for production with TF Serving.
+ * Added utility `tf.keras.utils.FeatureSpace`, a one-stop shop for
+ structured data preprocessing and encoding.
* Added `tf.SparseTensor` input support to `tf.keras.layers.Embedding`
layer. The layer now accepts a new boolean argument `sparse`. If
`sparse` is set to True, the layer returns a SparseTensor instead of a
dense Tensor. Defaults to False.
+ * Added `jit_compile` as a settable property to `tf.keras.Model`.
+ * Added `synchronized` optional parameter to `layers.BatchNormalization`.
+ * Added deprecation warning to
+ `layers.experimental.SyncBatchNormalization` and suggested to use
+ `layers.BatchNormalization` with `synchronized=True` instead.
+ * Updated `tf.keras.layers.BatchNormalization` to support masking of the
+ inputs (`mask` argument) when computing the mean and variance.
+ * Add `tf.keras.layers.Identity`, a placeholder pass-through layer.
+ * Add `show_trainable` option to `tf.keras.utils.model_to_dot` to display
+ layer trainable status in model plots.
+ * Add ability to save a `tf.keras.utils.FeatureSpace` object, via
+ `feature_space.save("myfeaturespace.keras")`, and reload it via
+ `feature_space = tf.keras.models.load_model("myfeaturespace.keras")`.
+ * Added utility `tf.keras.utils.to_ordinal` to convert class vector to
+ ordinal regression / classification matrix.
+
+* `tf.experimental.dtensor`:
+
+ * Coordination service now works with
+ `dtensor.initialize_accelerator_system`, and enabled by default.
+ * Add `tf.experimental.dtensor.is_dtensor` to check if a tensor is a
+ DTensor instance.
+
+* `tf.data`:
+
+ * Added support for alternative checkpointing protocol which makes it
+ possible to checkpoint the state of the input pipeline without having to
+ store the contents of internal buffers. The new functionality can be
+ enabled through the `experimental_symbolic_checkpoint` option of
+ `tf.data.Options()`.
+ * Added a new `rerandomize_each_iteration` argument for the
+ `tf.data.Dataset.random()` operation, which controls whether the
+ sequence of generated random numbers should be re-randomized every epoch
+ or not (the default behavior). If `seed` is set and
+ `rerandomize_each_iteration=True`, the `random()` operation will produce
+ a different (deterministic) sequence of numbers every epoch.
+ * Added a new `rerandomize_each_iteration` argument for the
+ `tf.data.Dataset.sample_from_datasets()` operation, which controls
+ whether the sequence of generated random numbers used for sampling
+ should be re-randomized every epoch or not. If `seed` is set and
+ `rerandomize_each_iteration=True`, the `sample_from_datasets()`
+ operation will use a different (deterministic) sequence of numbers every
+ epoch.
+
+* `tf.test`:
+
+ * Added `tf.test.experimental.sync_devices`, which is useful for
+ accurately measuring performance in benchmarks.
+
+* `tf.experimental.dtensor`:
+
+ * Added experimental support to ReduceScatter fuse on GPU (NCCL).
# Bug Fixes and Other Changes
@@ -39,6 +212,29 @@
*
*
+* `tf.random`
+ * Added non-experimental aliases for `tf.random.split` and
+ `tf.random.fold_in`, the experimental endpoints are still available
+ so no code changes are necessary.
+* `tf.experimental.ExtensionType`
+ * Added function `experimental.extension_type.as_dict()`, which converts an
+ instance of `tf.experimental.ExtensionType` to a `dict` representation.
+* `stream_executor`
+ * Top level `stream_executor` directory has been deleted, users should use
+ equivalent headers and targets under `compiler/xla/stream_executor`.
+* `tf.nn`
+ * Added `tf.nn.experimental.general_dropout`, which is similar to
+ `tf.random.experimental.stateless_dropout` but accepts a custom sampler
+ function.
+* `tf.types.experimental.GenericFunction`
+ * The `experimental_get_compiler_ir` method supports tf.TensorSpec
+ compilation arguments.
+* `tf.config.experimental.mlir_bridge_rollout`
+ * Removed enums `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` and
+ `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED` which are no longer used by
+ the tf2xla bridge
+
+
# Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
@@ -47,12 +243,6 @@ This release contains contributions from many people at Google, as well as:
# Release 2.11.0
-
-
-* `StatusOr::ConsumeValueOrDie` and `StatusOr::ValueOrDie`, both deprecated in
- TF 2.10 has been removed.
-
-
## Breaking Changes
* `tf.keras.optimizers.Optimizer` now points to the new Keras optimizer, and
old optimizers have moved to the `tf.keras.optimizers.legacy` namespace.
@@ -106,12 +296,6 @@ This release contains contributions from many people at Google, as well as:
only be implemented based on `tf.keras.optimizers.Optimizer`, the new
base class.
-## Known Caveats
-
-*
-*
-*
-
## Major Features and Improvements
* `tf.lite`:
@@ -160,7 +344,7 @@ This release contains contributions from many people at Google, as well as:
file is a protobuf containing the "fingerprint" of the SavedModel. See
the [RFC](https://github.com/tensorflow/community/pull/415) for more
details regarding its design and properties.
-
+
* `tf.data`:
* Graduated experimental APIs:
* [`tf.data.Dataset.ragged_batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#ragged_batch), which batches elements of `tf.data.Dataset`s into `tf.RaggedTensor`s.
@@ -185,11 +369,152 @@ This release contains contributions from many people at Google, as well as:
* `tf.SparseTensor`:
* Introduced `set_shape`, which sets the static dense shape of the sparse tensor and has the same semantics as `tf.Tensor.set_shape`.
+## Security
+
+* TF is currently using giflib 5.2.1 which has [CVE-2022-28506](https://nvd.nist.gov/vuln/detail/CVE-2022-28506). TF is not affected by the CVE as it does not use `DumpScreen2RGB` at all.
+* Fixes an OOB seg fault in `DynamicStitch` due to missing validation ([CVE-2022-41883](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41883))
+* Fixes an overflow in `tf.keras.losses.poisson` ([CVE-2022-41887](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41887))
+* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880))
+* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884))
+* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885))
+* Fixes an overflow in `ImageProjectiveTransformV2` ([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886))
+* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888))
+* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889))
+* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890))
+* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891))
+* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893))
+* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894))
+* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895))
+* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896))
+* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897))
+* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898))
+* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899))
+* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900))
+* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901))
+* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902))
+* Fixes an overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907))
+* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908))
+* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909))
+* Fixes an invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911))
+* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910))
+* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
-, , , , ,
+103yiran, 8bitmp3, Aakar Dwivedi, Alexander Grund, alif_elham, Aman Agarwal,
+amoitra, Andrei Ivanov, andreii, Andrew Goodbody, angerson, Ashay Rane,
+Azeem Shaikh, Ben Barsdell, bhack, Bhavani Subramanian, Cedric Nugteren,
+Chandra Kumar Ramasamy, Christopher Bate, CohenAriel, Cotarou, cramasam,
+Enrico Minack, Francisco Unda, Frederic Bastien, gadagashwini, Gauri1 Deshpande,
+george, Jake, Jeff, Jerry Ge, Jingxuan He, Jojimon Varghese, Jonathan Dekhtiar,
+Kaixi Hou, Kanvi Khanna, kcoul, Keith Smiley, Kevin Hu, Kun Lu, kushanam,
+Lianmin Zheng, liuyuanqiang, Louis Sugy, Mahmoud Abuzaina, Marius Brehler,
+mdfaijul, Meenakshi Venkataraman, Milos Puzovic, mohantym, Namrata-Ibm,
+Nathan John Sircombe, Nathan Luehr, Olaf Lipinski, Om Thakkar, Osman F Bayram,
+Patrice Vignola, Pavani Majety, Philipp Hack, Prianka Liz Kariat, Rahul Batra,
+RajeshT, Renato Golin, riestere, Roger Iyengar, Rohit Santhanam, Rsanthanam-Amd,
+Sadeed Pv, Samuel Marks, Shimokawa, Naoaki, Siddhesh Kothadi, Simengliu-Nv,
+Sindre Seppola, snadampal, Srinivasan Narayanamoorthy, sushreebarsa,
+syedshahbaaz, Tamas Bela Feher, Tatwai Chong, Thibaut Goetghebuer-Planchon,
+tilakrayal, Tom Anderson, Tomohiro Endo, Trevor Morris, vibhutisawant,
+Victor Zhang, Vremold, Xavier Bonaventura, Yanming Wang, Yasir Modak,
+Yimei Sun, Yong Tang, Yulv-Git, zhuoran.liu, zotanika
+
+# Release 2.10.1
+
+This release introduces several vulnerability fixes:
+
+* Fixes an OOB seg fault in `DynamicStitch` due to missing validation ([CVE-2022-41883](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41883))
+* Fixes an overflow in `tf.keras.losses.poisson` ([CVE-2022-41887](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41887))
+* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880))
+* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884))
+* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885))
+* Fixes an overflow in `ImageProjectiveTransformV2` ([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886))
+* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888))
+* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889))
+* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890))
+* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891))
+* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893))
+* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894))
+* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895))
+* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896))
+* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897))
+* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898))
+* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899))
+* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900))
+* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901))
+* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902))
+* Fixes an overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907))
+* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908))
+* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909))
+* Fixes an invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911))
+* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910))
+* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+
+# Release 2.9.3
+
+This release introduces several vulnerability fixes:
+
+* Fixes an overflow in `tf.keras.losses.poisson` ([CVE-2022-41887](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41887))
+* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880))
+* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884))
+* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885))
+* Fixes an overflow in `ImageProjectiveTransformV2` ([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886))
+* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888))
+* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889))
+* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890))
+* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891))
+* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893))
+* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894))
+* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895))
+* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896))
+* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897))
+* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898))
+* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899))
+* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900))
+* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901))
+* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902))
+* Fixes an overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907))
+* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908))
+* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909))
+* Fixes an invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911))
+* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910))
+* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+
+# Release 2.8.4
+
+This release introduces several vulnerability fixes:
+
+* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880))
+* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884))
+* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885))
+* Fixes an overflow in `ImageProjectiveTransformV2` ([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886))
+* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888))
+* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889))
+* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890))
+* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891))
+* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893))
+* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894))
+* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895))
+* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896))
+* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897))
+* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898))
+* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899))
+* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900))
+* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901))
+* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902))
+* Fixes an overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907))
+* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908))
+* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909))
+* Fixes an invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911))
+* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910))
+* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
+* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935))
# Release 2.10.0
@@ -10654,3 +10979,5 @@ answered questions, and were part of inspiring discussions.
# Release 0.5.0
Initial release of TensorFlow.
+
+
diff --git a/SECURITY.md b/SECURITY.md
index f6d414794c0..d6d47c4e635 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -273,21 +273,11 @@ TensorFlow uses the following disclosure process:
* An advisory is prepared (but not published) which details the problem and
steps for mitigation.
* The vulnerability is fixed and potential workarounds are identified.
-* We will attempt to cherry-pick the fix to the release branches used for all
- releases of TensorFlow that are at most one year old (though sometimes we
- might not patch all of them). The cherry-picks will occur as soon as possible
- and the patch releases will come at the same time as the next quarterly
- release.
-* Whenever patch releases are finalized, we will notify discuss@tensorflow.org.
* We will publish a security advisory for all fixed vulnerabilities.
For each vulnerability, we try to ingress it as soon as possible, given the size
of the team and the number of reports. Vulnerabilities will, in general, be
-batched to be fixed at the same time as a quarterly release. An exception to
-this rule is for high impact vulnerabilities where exploitation of models used
-for inference in products (i.e., not models created just to showcase a
-vulnerability) is possible. In these cases, we will attempt to do patch releases
-within an accelerated timeline, not waiting for the next quarterly release.
+batched to be fixed at the same time as a quarterly release.
Past security advisories are listed
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).
diff --git a/configure.py b/configure.py
index 135001ed103..6abde63a28a 100644
--- a/configure.py
+++ b/configure.py
@@ -36,7 +36,7 @@
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_SUPPORTED_ANDROID_NDK_VERSIONS = [
- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21
+ 19, 20, 21
]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
@@ -619,7 +619,7 @@ def prompt_loop_or_load_from_env(environ_cp,
'Assuming to be a scripting mistake.' %
(var_name, n_ask_attempts))
- if resolve_symlinks and os.path.islink(val):
+ if resolve_symlinks:
val = os.path.realpath(val)
environ_cp[var_name] = val
return val
@@ -718,7 +718,8 @@ def valid_build_tools(version):
def get_ndk_api_level(environ_cp, android_ndk_home_path):
- """Gets the appropriate NDK API level to use for the provided Android NDK path."""
+ """Gets the appropriate NDK API level to use for the provided Android NDK path.
+ """
# First check to see if we're using a blessed version of the NDK.
properties_path = '%s/source.properties' % android_ndk_home_path
@@ -756,7 +757,7 @@ def valid_api_level(api_level):
android_ndk_api_level = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_NDK_API_LEVEL',
- var_default='21', # 21 is required for ARM64 support.
+ var_default='26', # 26 is required to support AHardwareBuffer.
ask_for_var=('Please specify the (min) Android NDK API level to use. '
'[Available levels: %s]') % api_levels,
check_success=valid_api_level,
@@ -1188,6 +1189,9 @@ def main():
gcc_env = get_gcc_compiler(environ_cp)
if gcc_env is not None:
+ # Use gold linker if 'gcc' and if 'ppc64le'
+ write_to_bazelrc('build --linkopt="-fuse-ld=gold"')
+
# Get the linker version
ld_version = run_shell([gcc_env, '-Wl,-version']).split()
@@ -1215,8 +1219,6 @@ def main():
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')):
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
- write_action_env_to_bazelrc('ROCBLAS_TENSILE_LIBPATH',
- environ_cp.get('ROCM_PATH') + '/lib/library')
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('HIP_PLATFORM')):
write_action_env_to_bazelrc('HIP_PLATFORM', environ_cp.get('HIP_PLATFORM'))
diff --git a/fuzztest.bazelrc b/fuzztest.bazelrc
new file mode 100644
index 00000000000..360b3484ee9
--- /dev/null
+++ b/fuzztest.bazelrc
@@ -0,0 +1,47 @@
+### DO NOT EDIT. Generated file.
+#
+# To regenerate, run the following from your project's workspace:
+#
+# bazel run @com_google_fuzztest//bazel:setup_configs > fuzztest.bazelrc
+#
+# And don't forget to add the following to your project's .bazelrc:
+#
+# try-import %workspace%/fuzztest.bazelrc
+
+
+### Common options.
+#
+# Do not use directly.
+
+# Link with Address Sanitizer (ASAN).
+build:fuzztest-common --linkopt=-fsanitize=address
+
+# Standard define for "ifdef-ing" any fuzz test specific code.
+build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+
+# In fuzz tests, we want to catch assertion violations even in optimized builds.
+build:fuzztest-common --copt=-UNDEBUG
+
+# Enable libc++ assertions.
+# See https://libcxx.llvm.org/UsingLibcxx.html#enabling-the-safe-libc-mode
+build:fuzztest-common --copt=-D_LIBCPP_ENABLE_ASSERTIONS=1
+
+
+### FuzzTest build configuration.
+#
+# Use with: --config=fuzztest
+
+build:fuzztest --config=fuzztest-common
+
+# Link statically.
+build:fuzztest --dynamic_mode=off
+
+# We rely on the following flag instead of the compiler provided
+# __has_feature(address_sanitizer) to know that we have an ASAN build even in
+# the uninstrumented runtime.
+build:fuzztest --copt=-DADDRESS_SANITIZER
+
+# We apply coverage tracking and ASAN instrumentation to everything but the
+# FuzzTest framework itself (including GoogleTest and GoogleMock).
+build:fuzztest --per_file_copt=+//,-//fuzztest:,-googletest/.*,-googlemock/.*@-fsanitize=address,-fsanitize-coverage=inline-8bit-counters,-fsanitize-coverage=trace-cmp
+
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 46879082f93..0d27a8294f5 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -8,16 +8,26 @@ load(
"//tensorflow:tensorflow.bzl",
"VERSION",
"VERSION_MAJOR",
+ "check_deps",
"if_google",
"if_oss",
+ "if_xla_available",
"tf_cc_shared_object",
"tf_custom_op_library_additional_deps_impl",
+ "tf_monitoring_python_deps",
"tf_native_cc_binary",
+ "tsl_async_value_deps",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_binary_deps",
)
+load(
+ "//tensorflow/core/platform:build_config_root.bzl",
+ "if_static",
+ "tf_additional_plugin_deps",
+ "tf_additional_profiler_deps",
+)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl_ml",
@@ -28,6 +38,7 @@ load(
"ADDITIONAL_API_INDEXABLE_SETTINGS",
"tf_cc_shared_library",
)
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# copybara:uncomment_begin
# load("//tools/build_defs/license:license.bzl", "license")
@@ -95,12 +106,25 @@ PACKAGE_STATIC_DEPS = [
"@mkl_dnn_acl_compatible//:__subpackages__",
"@mkl_dnn_v1//:__subpackages__",
"@nccl_archive//:__subpackages__",
+ "@nvtx_archive//:__subpackages__",
"@org_sqlite//:__subpackages__",
"@platforms//:__subpackages__",
"@snappy//:__subpackages__",
"@upb//:__subpackages__",
"@zlib//:__subpackages__",
-]
+ "@dlpack//:__subpackages__",
+ "@arm_neon_2_x86_sse//:__subpackages__",
+ "@cpuinfo//:__subpackages__",
+ "@ruy//:__subpackages__",
+ "@XNNPACK//:__subpackages__",
+ "@pthreadpool//:__subpackages__",
+ "@FXdiv//:__subpackages__",
+ "@FP16//:__subpackages__",
+ "@clog//:__subpackages__",
+ "@flatbuffers//:__subpackages__",
+ "@nccl_archive//:__subpackages__",
+ "@triton//:__subpackages__",
+] + tsl_async_value_deps()
package(
# copybara:uncomment default_applicable_licenses = [":license"],
@@ -918,23 +942,21 @@ config_setting(
visibility = ["//visibility:public"],
)
-# copybara:uncomment_begin(configurable API loading)
-# bool_flag(
-# name = "enable_api_indexable",
-# build_setting_default = False,
-# )
-#
-# config_setting(
-# name = "api_indexable_flag",
-# flag_values = {":enable_api_indexable": "True"},
-# )
-#
-# selects.config_setting_group(
-# name = "api_indexable",
-# match_any = [":api_indexable_flag"] + ADDITIONAL_API_INDEXABLE_SETTINGS,
-# visibility = ["//visibility:public"],
-# )
-# copybara:uncomment_end
+bool_flag(
+ name = "enable_api_indexable",
+ build_setting_default = False,
+)
+
+config_setting(
+ name = "api_indexable_flag",
+ flag_values = {":enable_api_indexable": "True"},
+)
+
+selects.config_setting_group(
+ name = "api_indexable",
+ match_any = [":api_indexable_flag"] + ADDITIONAL_API_INDEXABLE_SETTINGS,
+ visibility = ["//visibility:public"],
+)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
@@ -949,6 +971,8 @@ package_group(
"//learning/brain/tfrt/...",
"//learning/lib/ami/simple_ml/...",
"//learning/pathways/...",
+ "//learning/serving/contrib/tfrt/mlir/canonical_ops/...",
+ "//perftools/accelerators/xprof/integration_tests/...",
"//smartass/brain/configure/...",
"//tensorflow/...",
"//tensorflow_decision_forests/...",
@@ -967,12 +991,12 @@ package_group(name = "ndarray_tensor_allow_list")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
# If this is modified, then copy.bara.sky must also be modified.
-package_group(name = "types_whitelist")
+package_group(name = "types_allowlist")
# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
# LINT.IfChange
-package_group(name = "structured_tensor_whitelist")
+package_group(name = "structured_tensor_allowlist")
# LINT.ThenChange(copy.bara.sky)
filegroup(
@@ -1081,28 +1105,38 @@ tf_cc_shared_library(
linkstatic = 1,
per_os_targets = True,
roots = [
- "//tensorflow/c/experimental/filesystem:filesystem_interface",
- "//tensorflow/c/experimental/stream_executor:stream_executor",
- "//tensorflow/c:env",
- "//tensorflow/c:kernels",
- "//tensorflow/c:kernels_experimental",
- "//tensorflow/c:logging",
- "//tensorflow/c:ops",
- "//tensorflow/cc/saved_model:fingerprinting_impl",
- "//tensorflow/cc/saved_model:loader_lite_impl",
- "//tensorflow/cc/saved_model:metrics_impl",
- "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",
- "//tensorflow/core/common_runtime:core_cpu_impl",
- "//tensorflow/core:framework_internal_impl",
- "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
- "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
- "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
- "//tensorflow/core:lib_internal_impl",
- "//tensorflow/core/profiler:profiler_impl",
- "//tensorflow/core/util:determinism", # Must be linked and exported to libtensorflow_framework.so.
- "//tensorflow/lite/kernels/shim:tf_kernel_shim",
- "//tensorflow/compiler/xla/stream_executor:stream_executor_impl",
- ] + tf_additional_binary_deps(),
+ "//tensorflow/c/experimental/filesystem:filesystem_interface",
+ "//tensorflow/c/experimental/stream_executor:stream_executor",
+ "//tensorflow/c:env",
+ "//tensorflow/c:kernels",
+ "//tensorflow/c:kernels_experimental",
+ "//tensorflow/c:logging",
+ "//tensorflow/c:ops",
+ "//tensorflow/cc/saved_model:fingerprinting_impl",
+ "//tensorflow/cc/saved_model:loader_lite_impl",
+ "//tensorflow/cc/saved_model:metrics_impl",
+ "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",
+ "//tensorflow/core/common_runtime:core_cpu_impl",
+ "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
+ "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
+ "//tensorflow/core:framework_internal_impl",
+ "//tensorflow/core/framework:tensor",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
+ "//tensorflow/core:lib_internal_impl",
+ "//tensorflow/core/profiler:profiler_impl",
+ "//tensorflow/core/util:determinism", # Must be linked and exported to libtensorflow_framework.so.
+ "//tensorflow/lite/kernels/shim:tf_kernel_shim",
+ "//tensorflow/compiler/xla/stream_executor:stream_executor_impl",
+ "//tensorflow/tsl/framework:bfc_allocator",
+ "//tensorflow/tsl/framework:metrics",
+ ] + tf_additional_binary_deps() +
+ # TODO(b/259305727): Remove this select and include captured_function in macos builds.
+ select({
+ "//tensorflow:macos": [],
+ "//conditions:default": [
+ "//tensorflow/core/data:captured_function",
+ ],
+ }),
soversion = VERSION,
static_deps = PACKAGE_STATIC_DEPS,
visibility = ["//visibility:public"],
@@ -1193,6 +1227,9 @@ tf_cc_shared_library(
"//tensorflow:macos": ["//tensorflow:libtensorflow_framework.%s.dylib" % VERSION],
"//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION],
}),
+ exports_filter = [
+ "//:__subpackages__",
+ ],
framework_so = [],
linkopts = select({
"//tensorflow:macos": [
@@ -1206,41 +1243,168 @@ tf_cc_shared_library(
}),
per_os_targets = True,
roots = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
- "//tensorflow/cc:const_op",
"//tensorflow/cc:scope",
- ],
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/data:standalone",
+ # Exports for pywrap_tensorflow_internal. Many of these are transitive
+      # dependencies of the above, but must be explicitly listed for
+ # cc_shared_library to work.
+ "//tensorflow/c/eager:c_api_experimental",
+ "//tensorflow/c/eager:c_api_internal",
+ "//tensorflow/c/eager:dlpack",
+ "//tensorflow/c/eager:tape",
+ "//tensorflow/c/eager:tfe_context_internal",
+ "//tensorflow/c/eager:tfe_op_internal",
+ "//tensorflow/c/eager:tfe_tensorhandle_internal",
+ "//tensorflow/c/experimental/gradients",
+ "//tensorflow/c/experimental/gradients/tape",
+ "//tensorflow/c/experimental/ops",
+ "//tensorflow/c:c_api_experimental",
+ "//tensorflow/c:c_api_internal",
+ "//tensorflow/c:c_api_no_xla",
+ "//tensorflow/c:checkpoint_reader",
+ "//tensorflow/c:tensor_interface",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/c:tf_tensor_internal",
+ "//tensorflow/cc/saved_model:loader",
+ "//tensorflow/compiler/mlir/lite/metrics:error_collector",
+ "//tensorflow/compiler/mlir/lite/python:flatbuffer_to_mlir",
+ "//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer",
+ "//tensorflow/compiler/mlir/lite/python:jax_to_tfl_flatbuffer",
+ "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
+ "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+ "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
+ "//tensorflow/compiler/mlir/python:mlir",
+ "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
+ "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
+ "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
+ "//tensorflow/compiler/mlir/tensorflow:error_util",
+ "//tensorflow/compiler/mlir/tensorflow:export_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
+ "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
+ "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
+ "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
+ "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/core",
+ "//tensorflow/core/common_runtime/eager:context",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/config:flag_defs",
+ "//tensorflow/core/config:flags",
+ "//tensorflow/core/data/service:dispatcher_client",
+ "//tensorflow/core/data/service:grpc_util",
+ "//tensorflow/core/data/service:py_utils",
+ "//tensorflow/core/data/service:server_lib",
+ "//tensorflow/core/debug",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/function/runtime_client:runtime_client_cc",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/clusters:single_machine",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/costs:graph_memory",
+ "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
+ "//tensorflow/core/grappler/optimizers:meta_optimizer",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:grappler_item_builder",
+ "//tensorflow/core/kernels:data_service_ops",
+ "//tensorflow/core/kernels:dataset_ops",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/core/platform:path",
+ "//tensorflow/core/platform:stacktrace_handler",
+ "//tensorflow/core/platform:statusor",
+ "//tensorflow/core/platform:stringpiece",
+ "//tensorflow/core/platform:types",
+ "//tensorflow/core/profiler/internal:print_model_analysis",
+ "//tensorflow/core/profiler/lib:traceme",
+ "//tensorflow/core/profiler/rpc/client:profiler_client_impl",
+ "//tensorflow/core/profiler/rpc:profiler_server_impl",
+ "//tensorflow/core/util:managed_stack_trace",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:direct_session",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:reader_base",
+ "//tensorflow/core:script_ops_op_lib",
+ "//tensorflow/distribute/experimental/rpc/kernels:rpc_ops",
+ "//tensorflow/dtensor/cc:dtensor_device_cc",
+ "//tensorflow/dtensor/cc:tensor_layout",
+ "//tensorflow/lite/c:common",
+ "//tensorflow/lite/core/api",
+ "//tensorflow/lite/delegates/flex:delegate",
+ "//tensorflow/lite/kernels/internal:compatibility",
+ "//tensorflow/lite/kernels:builtin_ops",
+ "//tensorflow/lite/kernels:reference_ops",
+ "//tensorflow/lite/schema:schema_fbs",
+ "//tensorflow/lite/toco/logging:conversion_log_util",
+ "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc",
+ "//tensorflow/lite/toco:model_flags_proto_cc",
+ "//tensorflow/lite/toco:toco_convert",
+ "//tensorflow/lite/toco:toco_flags_proto_cc",
+ "//tensorflow/lite/toco:toco_graphviz_dump_options",
+ "//tensorflow/lite/toco:toco_port",
+ "//tensorflow/lite/toco:toco_tooling",
+ "//tensorflow/lite/toco:tooling_util",
+ "//tensorflow/lite/toco:types_proto_cc",
+ "//tensorflow/lite:framework",
+ "//tensorflow/lite:shared_library",
+ "//tensorflow/lite:stateful_error_reporter",
+ "//tensorflow/lite:string_util",
+ "//tensorflow/lite:util",
+ "//tensorflow/python/grappler:cost_analyzer_lib",
+ "//tensorflow/tools/graph_transforms:transform_graph_lib",
+ ] + (tf_monitoring_python_deps() +
+ tf_additional_plugin_deps() +
+ tf_additional_profiler_deps()) + if_xla_available([
+ "//tensorflow/compiler/aot:tfcompile_lib",
+ ]) + if_static(extra_deps = [
+ "//tensorflow/core/platform:tensor_float_32_utils",
+ "//tensorflow/core/platform:enable_tf2_utils",
+ ]) + if_oss([
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_session",
+ ]),
soversion = VERSION,
static_deps = PACKAGE_STATIC_DEPS,
visibility = ["//visibility:public"],
win_def_file = ":tensorflow_filtered_def_file",
- deps = [
- "//tensorflow/c:c_api",
- "//tensorflow/c:env",
- "//tensorflow/c:kernels",
- "//tensorflow/c:kernels_experimental",
- "//tensorflow/c:logging",
- "//tensorflow/c:ops",
- "//tensorflow/c/eager:c_api",
- "//tensorflow/c/experimental/filesystem:filesystem_interface",
- "//tensorflow/c/experimental/stream_executor:stream_executor",
- "//tensorflow/cc/saved_model:fingerprinting_impl",
- "//tensorflow/cc/saved_model:loader_lite_impl",
- "//tensorflow/cc/saved_model:metrics_impl",
- "//tensorflow/core:framework_internal_impl",
- "//tensorflow/core:lib_internal_impl",
- "//tensorflow/core:tensorflow",
- "//tensorflow/core/data:standalone",
- "//tensorflow/core/common_runtime:core_cpu_impl",
- "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
- "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
- "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
- "//tensorflow/core/profiler:profiler_impl",
- "//tensorflow/core/util:determinism",
- "//tensorflow/lite/kernels/shim:tf_kernel_shim",
- "//tensorflow/compiler/xla/stream_executor:stream_executor_impl",
- ] + tf_additional_binary_deps(),
+)
+
+# To avoid duplication, check that the C++ or python library does not depend on
+# the stream executor cuda plugins. Targets that want to use cuda APIs should
+# instead depend on the dummy plugins in //tensorflow/tsl/platform/default/build_config
+# and use header only targets.
+# TODO(ddunleavy): This seems completely broken. :tensorflow_cc depends on
+# cuda_platform from tf_additional_binary_deps and this doesn't break.
+check_deps(
+ name = "cuda_plugins_check_deps",
+ disallowed_deps = if_static(
+ [],
+ otherwise = [
+ "//tensorflow/compiler/xla/stream_executor/cuda:all_runtime",
+ "//tensorflow/compiler/xla/stream_executor/cuda:cuda_driver",
+ "//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform",
+ "//tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin",
+ "//tensorflow/compiler/xla/stream_executor/cuda:cufft_plugin",
+ "//tensorflow/compiler/xla/stream_executor/cuda:curand_plugin",
+ "//tensorflow/compiler/xla/stream_executor:cuda_platform",
+ ],
+ ),
+ deps = if_cuda([
+ "//tensorflow:tensorflow_cc",
+ "//tensorflow/python:pywrap_tensorflow_internal",
+ ]),
)
# ** Targets for Windows build (start) **
@@ -1344,7 +1508,7 @@ genrule(
"//tensorflow/c/eager:headers",
"//tensorflow/cc:headers",
"//tensorflow/core:headers",
- "//tensorflow/stream_executor:stream_executor_install_hdrs",
+ "//tensorflow/compiler/xla/stream_executor:stream_executor_install_hdrs",
],
outs = ["include"],
cmd = """
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 3bb0bb91ba6..cd3cbac7a96 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -103,8 +103,6 @@
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
-# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
-# directories.
# TODO(gunan): Find a better location for this code snippet.
from tensorflow.python.framework import load_library as _ll
from tensorflow.python.lib.io import file_io as _fi
@@ -146,6 +144,11 @@ def _running_from_pip_package():
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
+if _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH", ""):
+ _ll.load_pluggable_device_library(
+ _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH")
+ )
+
# Add module aliases
if hasattr(_current_module, 'keras'):
# It is possible that keras is a lazily loaded module, which might break when
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
index f11fedce109..6c42fea562f 100644
--- a/tensorflow/api_template_v1.__init__.py
+++ b/tensorflow/api_template_v1.__init__.py
@@ -145,8 +145,6 @@
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
-# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
-# directories.
# TODO(gunan): Find a better location for this code snippet.
from tensorflow.python.framework import load_library as _ll
from tensorflow.python.lib.io import file_io as _fi
@@ -187,6 +185,11 @@ def _running_from_pip_package():
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
+if _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH", ""):
+ _ll.load_pluggable_device_library(
+ _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH")
+ )
+
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if _typing.TYPE_CHECKING:
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index dbd90e1d01f..3c1568b7091 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -2,7 +2,7 @@
# C API for TensorFlow, for use by client language bindings.
load("@bazel_skylib//lib:selects.bzl", "selects")
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"check_deps",
@@ -18,6 +18,7 @@ load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_cuda_cc_test")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
@@ -39,6 +40,7 @@ filegroup(
"tf_tensor.h",
"tf_tstring.h",
"//tensorflow/core/platform:ctstring",
+ "//tensorflow/tsl/c:headers",
] + if_tensorrt([
"//tensorflow/compiler/tf2tensorrt:headers",
]),
@@ -60,7 +62,8 @@ filegroup(
"*test*",
],
) + [
- "//tensorflow/core/platform:ctstring",
+ "//tensorflow/tsl/c:srcs",
+ "//tensorflow/tsl/platform:ctstring",
"//tensorflow/cc:srcs_no_runtime",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
@@ -79,6 +82,7 @@ cc_library(
"tf_buffer_internal.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
+ "//tensorflow/tsl/c:tsl_status_internal_headers",
],
visibility = [
"//tensorflow/core:__pkg__",
@@ -86,6 +90,22 @@ cc_library(
],
)
+cc_library(
+ name = "c_api_headers",
+ hdrs = [
+ "c_api.h",
+ "c_api_macros.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tf_attrtype",
+ ":tf_buffer",
+ ":tf_datatype",
+ ":tf_status_headers",
+ ":tf_tstring",
+ ],
+)
+
tf_cuda_library(
name = "c_api_internal",
hdrs = [
@@ -184,6 +204,7 @@ tf_cuda_library(
":tf_tensor_internal",
":tf_tstring",
"//tensorflow/core/platform:tstring",
+ "//tensorflow/tsl/c:tsl_status",
] + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
@@ -213,7 +234,7 @@ tf_cuda_library(
],
copts = tf_copts(),
visibility = [
- "//tensorflow/c:__subpackages__",
+ "//tensorflow:__subpackages__",
"//tensorflow/python:__subpackages__",
],
deps = [
@@ -273,6 +294,7 @@ tf_cuda_library(
hdrs = [
"tf_status.h",
"tf_status_internal.h",
+ "//tensorflow/tsl/c:tsl_status_internal_headers",
],
visibility = [
"//tensorflow/c:__subpackages__",
@@ -285,7 +307,11 @@ tf_cuda_library(
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
"//tensorflow/core/transforms:__subpackages__",
],
- deps = select({
+ deps = [
+ "//tensorflow/tsl/platform:status",
+ "//tensorflow/tsl/c:tsl_status",
+ "//tensorflow/tsl/c:tsl_status_internal",
+ ] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
@@ -297,7 +323,10 @@ tf_cuda_library(
filegroup(
name = "tf_status_internal_headers",
- srcs = ["tf_status_internal.h"],
+ srcs = [
+ "tf_status_internal.h",
+ "//tensorflow/tsl/c:tsl_status_internal_headers",
+ ],
visibility = [
"//tensorflow/python:__subpackages__",
],
@@ -331,9 +360,11 @@ cc_library(
name = "tf_status",
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
+ copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":tf_status_internal",
+ "//tensorflow/tsl/c:tsl_status",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
@@ -344,22 +375,13 @@ cc_library(
}),
)
-tf_cc_test(
- name = "tf_status_test",
- srcs = ["tf_status_test.cc"],
- deps = [
- ":tf_status",
- ":tf_status_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
cc_library(
name = "tf_status_headers",
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/tsl/c:tsl_status",
+ ],
)
cc_library(
@@ -374,10 +396,12 @@ cc_library(
"tf_tensor.h",
"tf_tstring.h",
],
+ copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:tstring",
+ "//tensorflow/tsl/c:tsl_status",
],
)
@@ -406,6 +430,7 @@ cc_library(
name = "tf_datatype",
srcs = ["tf_datatype.cc"],
hdrs = ["tf_datatype.h"],
+ copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
@@ -422,6 +447,7 @@ cc_library(
name = "tf_tensor",
srcs = ["tf_tensor.cc"],
hdrs = ["tf_tensor.h"],
+ copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
@@ -475,6 +501,7 @@ cc_library(
hdrs = [
"tf_buffer.h",
],
+ copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":tf_buffer_internal",
@@ -570,24 +597,9 @@ tf_cuda_library(
deps = [
":tf_status",
":tf_status_internal",
- ] + select({
- "//tensorflow:android": [
- "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
- ],
- "//conditions:default": [
- "//tensorflow/core:lib",
- ],
- }),
-)
-
-tf_cc_test(
- name = "tf_status_helper_test",
- srcs = ["tf_status_helper_test.cc"],
- deps = [
- ":tf_status_helper",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/core/platform:status",
+ "//tensorflow/tsl/c:tsl_status_helper",
],
)
@@ -804,7 +816,6 @@ tf_cuda_cc_test(
],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
- # linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":c_api_internal",
@@ -849,6 +860,7 @@ tf_cc_test(
data = [
"testdata/tf_record",
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
+ "//tensorflow/core/common_runtime/next_pluggable_device/c:test_next_pluggable_device_plugin.so",
],
extra_copts = if_google(["-DTENSORFLOW_NO_SHARED_OBJECTS=1"]),
linkopts = select({
@@ -861,7 +873,6 @@ tf_cc_test(
],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
- # linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":c_api_experimental",
@@ -934,7 +945,6 @@ tf_cuda_cc_test(
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
- # linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":env",
@@ -955,7 +965,6 @@ tf_cuda_cc_test(
tags = ["no_cuda_on_cpu_tap"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
- # linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":kernels",
@@ -982,7 +991,6 @@ tf_cc_test(
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
- # linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":ops",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 96e1268f62d..da62fc35bc0 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -117,7 +117,13 @@ const char* TF_Version() { return TF_VERSION_STRING; }
// --------------------------------------------------------------------------
// --------------------------------------------------------------------------
-TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
+TF_SessionOptions* TF_NewSessionOptions() {
+ TF_SessionOptions* out = new TF_SessionOptions;
+ // Disable optimizations for static graph to allow calls to Session::Extend.
+ out->options.config.mutable_experimental()
+ ->set_disable_optimize_for_static_graph(true);
+ return out;
+}
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
void TF_SetTarget(TF_SessionOptions* options, const char* target) {
@@ -129,6 +135,9 @@ void TF_SetConfig(TF_SessionOptions* options, const void* proto,
if (!options->options.config.ParseFromArray(proto, proto_len)) {
status->status = InvalidArgument("Unparseable ConfigProto");
}
+ // Disable optimizations for static graph to allow calls to Session::Extend.
+ options->options.config.mutable_experimental()
+ ->set_disable_optimize_for_static_graph(true);
}
void TF_TensorFromProto(const TF_Buffer* from, TF_Tensor* to,
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 523f5c6e609..3a05e1e64db 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -758,3 +758,9 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
delete lib_handle;
}
+
+void TF_GraphRemoveFunction(TF_Graph* g, const char* func_name,
+ TF_Status* status) {
+ tensorflow::mutex_lock l(g->mu);
+ status->status = g->graph.mutable_flib_def()->RemoveFunction(func_name);
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index ac41bb5a9ca..aec1e875eaf 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -329,6 +329,12 @@ TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
TF_Library* lib_handle);
+// Removes `func_name` from `g`. If `func_name` is not in `g`, an error will be
+// returned.
+TF_CAPI_EXPORT extern void TF_GraphRemoveFunction(TF_Graph* g,
+ const char* func_name,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index e47b7d0b0f7..63013c3fe46 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "absl/types/optional.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/c/eager/c_api.h"
@@ -255,5 +256,110 @@ TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#endif // !defined(PLATFORM_WINDOWS)
}
+TEST(CAPI_EXPERIMENTAL, LibraryNextPluggableDeviceLoadFunctions) {
+ // TODO(penpornk): Enable this test on Windows.
+#if !defined(PLATFORM_WINDOWS)
+#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+ // Load the library.
+ TF_Status* status = TF_NewStatus();
+ string lib_path =
+ tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
+ "tensorflow", "core", "common_runtime", "next_pluggable_device", "c",
+ "test_next_pluggable_device_plugin.so"));
+ TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
+ TF_Code code = TF_GetCode(status);
+ string status_msg(TF_Message(status));
+ TF_DeleteStatus(status);
+ ASSERT_EQ(TF_OK, code) << status_msg;
+ TF_DeletePluggableDeviceLibraryHandle(lib);
+#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
+#endif // !defined(PLATFORM_WINDOWS)
+}
+
+void DefineFunction(const char* name, TF_Function** func,
+ const char* description = nullptr,
+ bool append_hash = false) {
+ std::unique_ptr func_graph(
+ TF_NewGraph(), TF_DeleteGraph);
+ std::unique_ptr s(TF_NewStatus(),
+ TF_DeleteStatus);
+
+ TF_Operation* feed = Placeholder(func_graph.get(), s.get());
+ TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
+
+ TF_Output inputs[] = {{feed, 0}};
+ TF_Output outputs[] = {{neg, 0}};
+ *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
+ /*opers=*/nullptr, 1, inputs, 1, outputs,
+ /*output_names=*/nullptr,
+ /*opts=*/nullptr, description, s.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
+ ASSERT_NE(*func, nullptr);
+}
+
+class CApiExperimentalFunctionTest : public ::testing::Test {
+ protected:
+ CApiExperimentalFunctionTest()
+ : s_(TF_NewStatus()), func_graph_(TF_NewGraph()), func_(nullptr) {}
+
+ void SetUp() override {}
+
+ ~CApiExperimentalFunctionTest() override {
+ TF_DeleteFunction(func_);
+ TF_DeleteGraph(func_graph_);
+ TF_DeleteStatus(s_);
+ }
+
+ const char* func_name_ = "MyFunc";
+ TF_Status* s_;
+ TF_Graph* func_graph_;
+ TF_Function* func_;
+};
+
+TEST_F(CApiExperimentalFunctionTest, GraphRemoveFunction) {
+ TF_Function* funcs[1];
+ DefineFunction(func_name_, &func_);
+
+ TF_GraphCopyFunction(func_graph_, func_, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ EXPECT_EQ(TF_GraphNumFunctions(func_graph_), 1);
+ EXPECT_EQ(TF_GraphGetFunctions(func_graph_, funcs, 1, s_), 1);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_GraphRemoveFunction(func_graph_, func_name_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ EXPECT_EQ(TF_GraphNumFunctions(func_graph_), 0);
+ EXPECT_EQ(TF_GraphGetFunctions(func_graph_, funcs, 1, s_), 0);
+
+ TF_DeleteFunction(funcs[0]);
+}
+
+TEST_F(CApiExperimentalFunctionTest, EmptyGraphRemoveNonExistentFunction) {
+ TF_GraphRemoveFunction(func_graph_, "wrong_name", s_);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Tried to remove non-existent function 'wrong_name'."),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiExperimentalFunctionTest, GraphRemoveNonExistentFunction) {
+ TF_Function* funcs[1];
+ DefineFunction(func_name_, &func_);
+
+ TF_GraphCopyFunction(func_graph_, func_, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ EXPECT_EQ(TF_GraphNumFunctions(func_graph_), 1);
+ EXPECT_EQ(TF_GraphGetFunctions(func_graph_, funcs, 1, s_), 1);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_GraphRemoveFunction(func_graph_, "wrong_name", s_);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Tried to remove non-existent function 'wrong_name'."),
+ string(TF_Message(s_)));
+ TF_DeleteFunction(funcs[0]);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 537b61f3558..a13a1458553 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -185,7 +185,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
if (control_output_names) {
control_output_names_vec.reserve(ncontrol_outputs);
for (int i = 0; i < ncontrol_outputs; ++i) {
- control_output_names_vec.push_back(string(output_names[i]));
+ control_output_names_vec.push_back(string(control_output_names[i]));
}
}
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index 9722841691f..79d2841b724 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -211,6 +211,14 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
std::string getTF_OutputDebugString(TF_Output node);
+// Set whether to propagate assigned device information when constructing a new
+// Graph from a GraphDef. By default assigned device information is not copied
+// and is re-computed by the runtime.
+inline void TF_ImportGraphDefOptionsSetPropagateDeviceSpec(
+ TF_ImportGraphDefOptions* opts, unsigned char propagate_device_spec) {
+ opts->opts.propagate_device_spec = propagate_device_spec;
+}
+
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index c1aeb831bce..43dfe5155de 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -9,16 +9,13 @@ load(
"tf_cuda_library",
)
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "internal_tfrt_deps")
-load(
- "//tensorflow/core/platform:build_config.bzl",
- "tf_kernel_tests_linkstatic",
-)
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
@@ -131,7 +128,7 @@ filegroup(
"tfe_tensorhandle_internal.h",
],
visibility = [
- "//tensorflow/core/function:__pkg__",
+ "//tensorflow/core/function/runtime_client:__pkg__",
"//tensorflow/python:__subpackages__",
],
)
@@ -256,7 +253,6 @@ tf_cuda_cc_test(
"gradients_test.cc",
],
args = ["--heap_check="],
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":abstract_context",
@@ -293,7 +289,6 @@ tf_cuda_cc_test(
"unified_api_test.cc",
],
args = ["--heap_check="],
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156
deps = [
":c_api_experimental",
@@ -337,7 +332,6 @@ tf_cuda_cc_test(
"gradient_checker_test.cc",
],
args = ["--heap_check="],
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"no_cuda_asan", # b/175330074
],
@@ -755,7 +749,10 @@ tf_cuda_cc_test(
tags = [
"no_oss", # TODO(b/200848572)
"no_windows",
+ # TODO(b/136478427): sanitizers report issues due to unclean exit.
"noasan", # leaks gRPC server instances
+ "nomsan", # b/229991646: use of destructed memory due to unclean exit.
+ "notsan", # b/259602430: race on destructed mutex due to unclean exit.
],
deps = [
":c_api",
@@ -885,9 +882,9 @@ tf_cuda_library(
}) + [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_status_helper",
- "//tensorflow/core/distributed_runtime/coordination:coordination_service_agent",
"//tensorflow/core/distributed_runtime/coordination:coordination_service_error_util",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@@ -900,6 +897,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core:gpu_runtime",
+ "//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent",
],
alwayslink = 1,
)
@@ -911,7 +909,6 @@ tf_cuda_cc_test(
"c_api_experimental_test.cc",
],
args = ["--heap_check="],
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":c_api",
@@ -934,7 +931,6 @@ tf_cuda_cc_test(
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check="],
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":c_api",
@@ -1015,7 +1011,7 @@ cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
- copts = [
+ copts = tf_copts() + [
"-fexceptions",
"-fno-strict-aliasing",
],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 13a9c797235..e3199b204f6 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -68,7 +68,6 @@ limitations under the License.
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \
!defined(PLATFORM_FUCHSIA)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
-#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA
#if !defined(IS_MOBILE_PLATFORM)
@@ -123,12 +122,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
opts->session_options.options,
static_cast(
opts->device_placement_policy),
- opts->async, opts->use_tfrt_distributed_runtime);
-#if !defined(IS_MOBILE_PLATFORM)
- tfrt_context->SetDistributedManager(
- tfrt::tf::CreateDistributedManagerContext(
- tfrt_context->GetCoreRuntime()->GetHostContext()));
-#endif // !IS_MOBILE_PLATFORM
+ opts->async);
return tensorflow::wrap(tfrt_context);
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index 149c6062d23..7eb22ed2c7c 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include
#include "absl/strings/match.h"
+#include "absl/time/time.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
-#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
+#include "tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.h"
using tensorflow::string;
@@ -539,11 +540,6 @@ void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
options->use_tfrt = use_tfrt;
}
-void TFE_ContextOptionsSetTfrtDistributedRuntime(
- TFE_ContextOptions* options, bool use_tfrt_distributed_runtime) {
- options->use_tfrt_distributed_runtime = use_tfrt_distributed_runtime;
-}
-
TFE_CancellationManager* TFE_NewCancellationManager() {
return tensorflow::wrap(new tensorflow::CancellationManager);
}
@@ -571,8 +567,10 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
status->status = ::tensorflow::OkStatus();
}
-TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue) {
- return new TFE_Executor(is_async, enable_streaming_enqueue);
+TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue,
+ int in_flight_nodes_limit) {
+ return new TFE_Executor(is_async, enable_streaming_enqueue,
+ in_flight_nodes_limit);
}
void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
@@ -785,7 +783,7 @@ void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key,
const char* value, TF_Status* status) {
tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
tensorflow::unwrap(ctx)->GetDistributedManager();
- tensorflow::CoordinationServiceAgent* coord_agent =
+ tsl::CoordinationServiceAgent* coord_agent =
dist_mgr->GetCoordinationServiceAgent();
if (coord_agent == nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -799,7 +797,7 @@ void TFE_GetConfigKeyValue(TFE_Context* ctx, const char* key,
TF_Buffer* value_buf, TF_Status* status) {
tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
tensorflow::unwrap(ctx)->GetDistributedManager();
- tensorflow::CoordinationServiceAgent* coord_agent =
+ tsl::CoordinationServiceAgent* coord_agent =
dist_mgr->GetCoordinationServiceAgent();
if (coord_agent == nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -824,7 +822,7 @@ void TFE_DeleteConfigKeyValue(TFE_Context* ctx, const char* key,
TF_Status* status) {
tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
tensorflow::unwrap(ctx)->GetDistributedManager();
- tensorflow::CoordinationServiceAgent* coord_agent =
+ tsl::CoordinationServiceAgent* coord_agent =
dist_mgr->GetCoordinationServiceAgent();
if (coord_agent == nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -838,7 +836,7 @@ void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code,
const char* error_message, TF_Status* status) {
tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
tensorflow::unwrap(ctx)->GetDistributedManager();
- tensorflow::CoordinationServiceAgent* coord_agent =
+ tsl::CoordinationServiceAgent* coord_agent =
dist_mgr->GetCoordinationServiceAgent();
if (coord_agent == nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -854,7 +852,7 @@ void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states,
TF_Status* status) {
tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
tensorflow::unwrap(ctx)->GetDistributedManager();
- tensorflow::CoordinationServiceAgent* coord_agent =
+ tsl::CoordinationServiceAgent* coord_agent =
dist_mgr->GetCoordinationServiceAgent();
if (coord_agent == nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -890,3 +888,18 @@ void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states,
}
status->status = tensorflow::OkStatus();
}
+
+void TFE_WaitAtBarrier(TFE_Context* ctx, const char* barrier_id,
+ int64_t barrier_timeout_in_ms, TF_Status* status) {
+ tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
+ tensorflow::unwrap(ctx)->GetDistributedManager();
+ tsl::CoordinationServiceAgent* coord_agent =
+ dist_mgr->GetCoordinationServiceAgent();
+ if (coord_agent == nullptr) {
+ status->status = tensorflow::errors::FailedPrecondition(
+ "Coordination service is not enabled.");
+ return;
+ }
+ status->status = coord_agent->WaitAtBarrier(
+ barrier_id, absl::Milliseconds(barrier_timeout_in_ms), {});
+}
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 704a093fbab..95d833f6f47 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -294,10 +294,6 @@ TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
-// Sets whether to use TFRT distributed runtime
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrtDistributedRuntime(
- TFE_ContextOptions* options, bool use_tfrt_distributed_runtime);
-
// Returns the context_id from the EagerContext which is used by the
// EagerService to maintain consistency between client and worker. The
// context_id is initialized with a dummy value and is later set when the worker
@@ -333,8 +329,16 @@ typedef struct TFE_Executor TFE_Executor;
// Creates a new eager Executor. Nodes in one executor are guaranteed to be
// executed in sequence. Assigning nodes to different executors allows executing
// nodes in parallel.
+// in_flight_nodes_limit: when is_async is true, this value controls the
+// maximum number of in flight async nodes. Enqueuing of additional async ops
+// after the limit is reached blocks until some inflight nodes finishes.
+// The effect is bounding the memory held by inflight TensorHandles that are
+// referenced by the inflight nodes.
+// A recommended value has not been established.
+// A value of 0 removes the limit, which is the behavior of TensorFlow 2.11.
+// When is_async is false, the value is ignored.
TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(
- bool is_async, bool enable_streaming_enqueue);
+ bool is_async, bool enable_streaming_enqueue, int in_flight_nodes_limit);
// Deletes the eager Executor without waiting for enqueued nodes. Please call
// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
@@ -724,6 +728,11 @@ TF_CAPI_EXPORT extern void TFE_GetTaskStates(TFE_Context* ctx,
const TF_Buffer& tasks,
void* states, TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_WaitAtBarrier(TFE_Context* ctx,
+ const char* barrier_id,
+ int64_t barrier_timeout_in_ms,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc
index f900de59060..68dbafc4d2a 100644
--- a/tensorflow/c/eager/c_api_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_experimental_test.cc
@@ -220,7 +220,8 @@ TEST(CAPI, ExecutorContextDestructionOrder) {
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(
- /*is_async=*/false, /*enable_streaming_enqueue=*/true);
+ /*is_async=*/false, /*enable_streaming_enqueue=*/true,
+ /*in_flight_nodes_limit=*/0);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteContext(ctx);
@@ -233,7 +234,8 @@ TEST(CAPI, ExecutorContextDestructionOrder) {
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(
- /*is_async=*/false, /*enable_streaming_enqueue=*/true);
+ /*is_async=*/false, /*enable_streaming_enqueue=*/true,
+ /*in_flight_nodes_limit=*/0);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteExecutor(executor);
@@ -275,7 +277,8 @@ TEST(CAPI, Function_ident_CPU) {
for (bool async : {false, true, false}) {
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(
- /*is_async=*/async, /*enable_streaming_enqueue=*/true);
+ /*is_async=*/async, /*enable_streaming_enqueue=*/true,
+ /*in_flight_nodes_limit=*/0);
TFE_ContextSetExecutorForThread(ctx, executor);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -327,7 +330,8 @@ void Executor_MatMul_CPU(bool async) {
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(
- /*is_async=*/async, /*enable_streaming_enqueue=*/true);
+ /*is_async=*/async, /*enable_streaming_enqueue=*/true,
+ /*in_flight_nodes_limit=*/0);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 8bec998681e..eff96826822 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -34,11 +34,6 @@ struct TFE_ContextOptions {
TFE_DEVICE_PLACEMENT_SILENT};
// If true, use TFRT backend
bool use_tfrt = false;
- // This option is effective only when use_tfrt is true. If true, TFRT will use
- // native TFRT distributed runtime. Otherwise, TFRT will use current runtime's
- // distributed runtime. Note that TFRT distributed runtime is in development
- // and not functionally complete.
- bool use_tfrt_distributed_runtime = false;
// Whether to run elementary eager ops wrapped in a call op.
bool run_eager_op_as_function = false;
// Whether to rewrite jit_compile functions.
diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h
index 9efb2fa85d6..4f96992e739 100644
--- a/tensorflow/c/eager/immediate_execution_distributed_manager.h
+++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h
@@ -20,8 +20,11 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
-namespace tensorflow {
+namespace tsl {
class CoordinationServiceAgent;
+}
+
+namespace tensorflow {
class ImmediateExecutionContext;
class ServerDef;
class WorkerEnv;
@@ -32,19 +35,19 @@ class ImmediateExecutionDistributedManager {
virtual ~ImmediateExecutionDistributedManager() {}
// Set up distributed execution environment on local and remote tasks.
- // When `reset_context` is true, initialize new cluster context state based on
- // cluster configurations provided in `server_def`; otherwise, update existing
- // context state with the provided `server_def`.
- // Contexts created on remote tasks will be considered stale and garbage
- // collected after `keep_alive_secs` of inactivity.
+ // When `reset_context` is true, initialize new cluster context state based
+ // on cluster configurations provided in `server_def`; otherwise, update
+ // existing context state with the provided `server_def`. Contexts created
+ // on remote tasks will be considered stale and garbage collected after
+ // `keep_alive_secs` of inactivity.
virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
bool reset_context,
int keep_alive_secs) = 0;
- // Set up a multi-client distributed execution environment. Must be called on
- // all tasks in the cluster.
- // This call internally coordinates with other tasks to initialize the eager
- // context and TF server for multi-client execution.
+ // Set up a multi-client distributed execution environment. Must be called
+ // on all tasks in the cluster. This call internally coordinates with other
+ // tasks to initialize the eager context and TF server for multi-client
+ // execution.
virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
// Check if the remote task is alive.
@@ -52,7 +55,7 @@ class ImmediateExecutionDistributedManager {
bool* is_alive) = 0;
// Get pointer to the coordination service agent instance.
- virtual CoordinationServiceAgent* GetCoordinationServiceAgent() = 0;
+ virtual tsl::CoordinationServiceAgent* GetCoordinationServiceAgent() = 0;
};
} // namespace tensorflow
diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index e528a7070ab..0de029ff449 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -6,6 +6,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
index 727c1f83396..fd054c9af9a 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
@@ -66,7 +66,8 @@ using ExecutorPtr = std::unique_ptr;
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
- explicit DeviceThread(const std::string& device, const bool is_async)
+ explicit DeviceThread(const std::string& device, const bool is_async,
+ const int in_flight_nodes_limit)
: status_(TF_NewStatus()),
// If the context's default exector is set to async, re-using that in
// each thread would cause collectives to deadlock. For consistency we
@@ -75,7 +76,9 @@ class DeviceThread {
// TODO(allenl): We should have an async API that works with the
// parallel device.
device_(device),
- executor_(TFE_NewExecutor(is_async, /*enable_streaming_enqueue=*/true)),
+ executor_(
+ TFE_NewExecutor(is_async, /*enable_streaming_enqueue=*/true,
+ /*in_flight_nodes_limit=*/in_flight_nodes_limit)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
@@ -282,13 +285,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
}
ParallelDevice::ParallelDevice(const std::vector& devices,
- const bool is_async)
+ bool is_async, int in_flight_nodes_limit)
: underlying_devices_(devices),
default_cancellation_manager_(absl::make_unique()) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
- device_threads_.emplace_back(
- new DeviceThread(devices[device_index].c_str(), is_async));
+ device_threads_.emplace_back(new DeviceThread(
+ devices[device_index].c_str(), is_async, in_flight_nodes_limit));
}
}
@@ -365,6 +368,26 @@ void ParallelDevice::StartExecute(TFE_Context* context,
}
}
+void ParallelDevice::StartExecute(
+ TFE_Context* context, const std::vector& inputs,
+ const char* operation_name, const TFE_OpAttrs* attributes,
+ int expected_max_outputs, CancellationManager& cancellation_manager,
+ absl::optional step_id) const {
+ for (int device_index = 0; device_index < underlying_devices_.size();
+ ++device_index) {
+ DeviceThread* device_thread = device_threads_[device_index].get();
+ std::vector device_inputs;
+ device_inputs.reserve(inputs.size());
+ for (int input_index = 0; input_index < inputs.size(); ++input_index) {
+ // Parallel tensors are divided between operations by device.
+ device_inputs.push_back(inputs[input_index][device_index].get());
+ }
+ device_thread->StartExecute(
+ context, operation_name, std::move(device_inputs), attributes,
+ expected_max_outputs, cancellation_manager, step_id);
+ }
+}
+
void ParallelDevice::AsyncWait(TFE_Context* context, TF_Status* status) const {
StatusPtr first_bad_status(nullptr);
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
index 80f81dd47a4..01581f40e05 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#include
+#include
#include
#include
@@ -44,6 +45,8 @@ class TensorHandleDeleter {
}
};
+// TODO(b/256016071): Replace this with `Safe_TFE_TensorHandlePtr` when
+// `Safe_TFE_TensorHandlePtr` is marked to be compatible on non-prod env.
using TensorHandlePtr = std::unique_ptr;
class ParallelTensor;
@@ -56,7 +59,7 @@ class ParallelDevice {
// Eager async execution is only supported when remote eager is not in use
// (b/157523095).
explicit ParallelDevice(const std::vector& devices,
- const bool is_async = false);
+ bool is_async = false, int in_flight_nodes_limit = 0);
~ParallelDevice();
@@ -118,12 +121,24 @@ class ParallelDevice {
//
// Set step_id to configure the step id used for rendezvous creation. step id
// of value -1 is reserved for global rendezvous and should not be set here.
+ //
+ // This function is overloaded so that if the inputs are constructed from
+ // `TensorWithLayout` we can use the one with `TensorHandlePtr` but
+ // if the inputs are directly `ParallelTensor` (for example, in the case of
+ // custom device execution) we can use the one with `ParallelTensor`.
void StartExecute(TFE_Context* context,
const std::vector& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs,
CancellationManager& cancellation_manager,
- absl::optional step_id = absl::nullopt) const;
+ std::optional step_id = std::nullopt) const;
+
+ void StartExecute(TFE_Context* context,
+ const std::vector& inputs,
+ const char* operation_name, const TFE_OpAttrs* attributes,
+ int expected_max_outputs,
+ CancellationManager& cancellation_manager,
+ std::optional step_id = std::nullopt) const;
// Blocks until the previous `StartExecute` has run `TFE_Execute` on each
// device. If is_async=false (constructor argument) this means the ops have
@@ -189,6 +204,7 @@ class ParallelTensor {
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
+ const TensorHandlePtr* tensor_data() const { return tensors_.data(); }
// If the `shape` argument to `FromTensorHandles` is specified, returns that.
//
diff --git a/tensorflow/c/eager/tfe_executor_internal.h b/tensorflow/c/eager/tfe_executor_internal.h
index 081b139bd34..7f55532af56 100644
--- a/tensorflow/c/eager/tfe_executor_internal.h
+++ b/tensorflow/c/eager/tfe_executor_internal.h
@@ -20,9 +20,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
struct TFE_Executor {
- explicit TFE_Executor(bool async, bool enable_streaming_enqueue)
- : owned_executor(
- new tensorflow::EagerExecutor(async, enable_streaming_enqueue)) {}
+ explicit TFE_Executor(bool async, bool enable_streaming_enqueue,
+ int in_flight_nodes_limit)
+ : owned_executor(new tensorflow::EagerExecutor(
+ async, enable_streaming_enqueue, in_flight_nodes_limit)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD
index 4d8ff231ce7..6c5c43fbb46 100644
--- a/tensorflow/c/experimental/filesystem/BUILD
+++ b/tensorflow/c/experimental/filesystem/BUILD
@@ -6,6 +6,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
@@ -42,6 +43,7 @@ cc_library(
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
+ "//tensorflow/tsl/platform:errors",
],
)
diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc
index 32b06697d77..b47748374fe 100644
--- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/util/ptr_util.h"
+#include "tensorflow/tsl/platform/errors.h"
// TODO(b/139060984): After all filesystems are converted, all calls to
// methods from `FileSystem` will have to be replaced to calls to private
@@ -561,8 +562,9 @@ Status RegisterFilesystemPlugin(const std::string& dso_path) {
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
- TF_RETURN_IF_ERROR(
- env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol),
+ "Failed to load TF_InitPlugin symbol for DSO: ", dso_path);
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
index 1d9bfc1a15f..bd2041b1d43 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
@@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD
index 9d655fd43b5..90acb2bf389 100644
--- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD
@@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD
index fb2f99f44ff..2ac57f6a731 100644
--- a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD
@@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD
index 90a99b05e38..1788cbd6551 100644
--- a/tensorflow/c/experimental/gradients/BUILD
+++ b/tensorflow/c/experimental/gradients/BUILD
@@ -5,10 +5,6 @@ load(
"if_libtpu",
"tf_cuda_cc_test",
)
-load(
- "//tensorflow/core/platform:build_config.bzl",
- "tf_kernel_tests_linkstatic",
-)
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
@@ -16,6 +12,7 @@ load(
# Library of gradient functions.
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
@@ -59,7 +56,7 @@ cc_library(
"nn_grad.h",
],
visibility = [
- "//tensorflow:internal",
+ "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private.
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
@@ -118,7 +115,6 @@ tf_cuda_cc_test(
"custom_gradient_test.cc",
],
args = ["--heap_check="], # TODO(b/174752220): Remove
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/c:tf_status_helper",
@@ -144,10 +140,7 @@ filegroup(
"nn_grad.h",
"not_differentiable.h",
],
- visibility = [
- "//tensorflow/core:__pkg__",
- "//tensorflow/python:__pkg__",
- ],
+ visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
@@ -156,7 +149,7 @@ cc_library(
srcs = ["grad_test_helper.cc"],
hdrs = ["grad_test_helper.h"],
visibility = [
- "//tensorflow:internal",
+ "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private.
],
deps = [
"//tensorflow/c/eager:gradient_checker",
@@ -175,7 +168,6 @@ tf_cuda_cc_test(
"nn_grad_test.cc",
],
args = ["--heap_check="], # TODO(b/174752220): Remove
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156,
deps = [
":grad_test_helper",
@@ -202,7 +194,6 @@ tf_cuda_cc_test(
"math_grad_test.cc",
],
args = ["--heap_check="], # TODO(b/174752220): Remove
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156,
deps = [
":grad_test_helper",
@@ -229,7 +220,6 @@ tf_cuda_cc_test(
"array_grad_test.cc",
],
args = ["--heap_check="], # TODO(b/174752220): Remove
- linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156,
deps = [
":grad_test_helper",
diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD
index 123f1908020..c29b7929d43 100644
--- a/tensorflow/c/experimental/gradients/tape/BUILD
+++ b/tensorflow/c/experimental/gradients/tape/BUILD
@@ -2,6 +2,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/grappler/BUILD b/tensorflow/c/experimental/grappler/BUILD
index 68bdcdcda70..482ec08efed 100644
--- a/tensorflow/c/experimental/grappler/BUILD
+++ b/tensorflow/c/experimental/grappler/BUILD
@@ -8,6 +8,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD
new file mode 100644
index 00000000000..890477266ea
--- /dev/null
+++ b/tensorflow/c/experimental/next_pluggable_device/BUILD
@@ -0,0 +1,34 @@
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ licenses = ["notice"],
+)
+
+cc_library(
+ name = "c_api",
+ srcs = ["c_api.cc"],
+ hdrs = ["c_api.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/c:kernels",
+ "//tensorflow/c:kernels_experimental_hdrs",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/c:tf_status_internal",
+ "//tensorflow/c:tf_tensor_internal",
+ "//tensorflow/compiler/jit:xla_launch_util",
+ "//tensorflow/compiler/xla/pjrt:pjrt_c_api_client",
+ "//tensorflow/compiler/xla/pjrt:pjrt_client",
+ "//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/common_runtime/next_pluggable_device",
+ "//tensorflow/core/common_runtime/next_pluggable_device:plugin_resource",
+ "//tensorflow/core/platform:status",
+ "//tensorflow/core/tfrt/common:async_value_tensor",
+ "//tensorflow/core/tfrt/common:pjrt_util",
+ "//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent",
+ "//tensorflow/tsl/platform:errors",
+ "//tensorflow/tsl/platform:statusor",
+ ],
+)
diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc
new file mode 100644
index 00000000000..1ff6e091507
--- /dev/null
+++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc
@@ -0,0 +1,333 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/experimental/next_pluggable_device/c_api.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/c/kernels_experimental.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/tf_status_internal.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/c/tf_tensor_internal.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h"
+#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tfrt/common/async_value_tensor.h"
+#include "tensorflow/core/tfrt/common/pjrt_util.h"
+#include "tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.h"
+#include "tensorflow/tsl/platform/errors.h"
+#include "tensorflow/tsl/platform/statusor.h"
+
+TF_Device* TF_GetDevice(TF_OpKernelContext* ctx) {
+ auto* cc_ctx = reinterpret_cast(ctx);
+ return reinterpret_cast(cc_ctx->device());
+}
+
+size_t TF_GetDeviceOrdinal(TF_Device* device) {
+ // TODO(chuanhao): make GetDeviceOrdinal a virtual member function in the base
+ // device class, instead of casting to `NextPluggableDevice`.
+ auto cc_device = reinterpret_cast(device);
+ return cc_device->GetDeviceOrdinal();
+}
+
+// -------------------------- Resource ---------------------------------------
+void TF_CreatePluginResource(TF_OpKernelContext* ctx,
+ const char* container_name,
+ const char* plugin_resource_name,
+ void* plugin_resource, void (*delete_func)(void*),
+ TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast(ctx);
+ tensorflow::PluginResource* cc_resource_ptr = new tensorflow::PluginResource(
+ plugin_resource, plugin_resource_name, delete_func);
+ auto cc_status =
+ cc_ctx->resource_manager()->Create(
+ container_name, plugin_resource_name, cc_resource_ptr);
+ Set_TF_Status_from_Status(status, cc_status);
+}
+
+void TF_LookupOrCreatePluginResource(
+ TF_OpKernelContext* ctx, const char* container_name,
+ const char* plugin_resource_name, void** result_plugin_resource,
+ void* (*create_func)(void*), void* create_func_args,
+ void (*delete_func)(void*), TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast(ctx);
+ auto* resource_mgr = cc_ctx->resource_manager();
+ tensorflow::core::RefCountPtr
+ tf_plugin_resource_ptr;
+ tensorflow::PluginResource* tf_plugin_resource = nullptr;
+
+ auto cc_status = resource_mgr->LookupOrCreate(
+ container_name, plugin_resource_name, &tf_plugin_resource,
+ [plugin_resource_name, create_func, create_func_args,
+ delete_func](tensorflow::PluginResource** new_resource) {
+ void* opaque_plugin_resource = create_func(create_func_args);
+ *new_resource = new tensorflow::PluginResource(
+ opaque_plugin_resource, plugin_resource_name, delete_func);
+ return tensorflow::OkStatus();
+ });
+
+ if (cc_status.ok()) {
+ tf_plugin_resource_ptr.reset(tf_plugin_resource);
+ *result_plugin_resource = tf_plugin_resource_ptr->GetOpaquePluginResource();
+ } else {
+ *result_plugin_resource = nullptr;
+ }
+ Set_TF_Status_from_Status(status, cc_status);
+}
+
+// ------------------------- VariableInfo ------------------------------------
+struct TF_VariableInfo {
+ TF_VariableInfo() = delete;
+ // TF_VariableInfo is constructed here by TensorFlow, and will be passed to
+ // plugin as an opaque pointer. Plugin will need to call C APIs below to
+ // operate on TF_VariableInfo (such as allocate temp tensor for the `var` held
+ // by the underlying tensorflow::VariableInfo.
+ TF_VariableInfo(int index, const std::string& name, tensorflow::Var* var) {
+ var_info = tensorflow::VariableInfo{index, name, var};
+ }
+
+ tensorflow::VariableInfo var_info{0, "", nullptr};
+};
+
+TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx,
+ int index,
+ TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast(ctx);
+ const tensorflow::Tensor& arg_tensor = cc_ctx->input(index);
+ tsl::Status cc_status;
+ if (arg_tensor.dtype() != tensorflow::DT_RESOURCE) {
+ cc_status = tsl::errors::InvalidArgument(
+ "Trying to obtain resource handle from Input[", index,
+ "], which is not type DT_RESOURCE.");
+ Set_TF_Status_from_Status(status, cc_status);
+ return nullptr;
+ }
+ const tensorflow::ResourceHandle& handle =
+ arg_tensor.flat()(0);
+ tensorflow::Var* variable;
+ cc_status = tensorflow::LookupResource(cc_ctx, handle, &variable);
+ return new TF_VariableInfo(index, handle.name(), variable);
+}
+
+void TF_LockVariableInfos(TF_VariableInfo** vars, int num_vars,
+ TF_Status* status) {
+ std::vector variable_ptrs;
+ variable_ptrs.reserve(num_vars);
+ for (int i = 0; i < num_vars; ++i) {
+ variable_ptrs.push_back(&(vars[i]->var_info));
+ }
+ tsl::Status cc_status = LockVariables(absl::MakeSpan(variable_ptrs));
+ tsl::Set_TF_Status_from_Status(status, cc_status);
+}
+
+void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx,
+ TF_VariableInfo* var_info,
+ TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast(ctx);
+ tsl::Status cc_status;
+ if (var_info == nullptr) {
+ cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL.");
+ Set_TF_Status_from_Status(status, cc_status);
+ return;
+ }
+ if (var_info->var_info.var() == nullptr) {
+ cc_status = tsl::errors::InvalidArgument(
+ "VariableInfo does not track a resource variable.");
+ Set_TF_Status_from_Status(status, cc_status);
+ return;
+ }
+
+ cc_status = cc_ctx->allocate_temp(var_info->var_info.var()->tensor()->dtype(),
+ var_info->var_info.var()->tensor()->shape(),
+ var_info->var_info.var()->tensor());
+ Set_TF_Status_from_Status(status, cc_status);
+}
+
+TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info,
+ TF_Status* status) {
+ tsl::Status cc_status;
+ if (var_info == nullptr) {
+ cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL.");
+ Set_TF_Status_from_Status(status, cc_status);
+ return nullptr;
+ }
+ if (var_info->var_info.var() == nullptr) {
+ cc_status = tsl::errors::InvalidArgument(
+ "VariableInfo does not track a resource variable.");
+ Set_TF_Status_from_Status(status, cc_status);
+ return nullptr;
+ }
+
+ tensorflow::Tensor* tensor = var_info->var_info.var()->tensor();
+ TF_Tensor* result_tensor =
+ tensorflow::TF_TensorFromTensor(*tensor, &cc_status);
+ Set_TF_Status_from_Status(status, cc_status);
+ return result_tensor;
+}
+
+void TF_DeleteVariableInfo(TF_VariableInfo* var_info) {
+ if (var_info != nullptr) {
+ delete var_info;
+ }
+}
+
+// --------------------- Coordination service --------------------------------
+TF_CoordinationServiceAgent* TF_GetCoordinationServiceAgent(
+ TF_OpKernelContext* ctx) {
+ auto* cc_ctx = reinterpret_cast(ctx);
+ return reinterpret_cast(
+ cc_ctx->coordination_service_agent());
+}
+
+bool TF_CoordinationServiceIsInitialized(TF_CoordinationServiceAgent* agent) {
+ if (agent == nullptr) return false;
+ auto* cc_agent = reinterpret_cast(agent);
+ return cc_agent->IsInitialized();
+}
+
+void TF_CoordinationServiceInsertKeyValue(const char* key, const char* value,
+ TF_CoordinationServiceAgent* agent,
+ TF_Status* status) {
+ auto* cc_agent = reinterpret_cast(agent);
+ tsl::Status cc_status = cc_agent->InsertKeyValue(key, value);
+ tsl::Set_TF_Status_from_Status(status, cc_status);
+}
+
+TF_Buffer* TF_CoordinationServiceGetKeyValue(const char* key,
+ TF_CoordinationServiceAgent* agent,
+ TF_Status* status) {
+ auto* cc_agent = reinterpret_cast(agent);
+ auto value = cc_agent->GetKeyValue(key);
+ tsl::Set_TF_Status_from_Status(status, value.status());
+ if (!value.ok()) {
+ return nullptr;
+ }
+ // Caller is responsible to call `TF_DeleteBuffer` to release the buffer.
+ TF_Buffer* result = TF_NewBuffer();
+ const std::string& value_str = *value;
+ void* data = malloc(value_str.length());
+ value_str.copy(static_cast(data), value_str.length(), 0);
+ result->data = data;
+ result->length = value_str.length();
+ result->data_deallocator = [](void* data, size_t length) { free(data); };
+ return result;
+}
+
+void TF_CoordinationServiceDeleteKeyValue(const char* key,
+ TF_CoordinationServiceAgent* agent,
+ TF_Status* status) {
+ auto* cc_agent = reinterpret_cast(agent);
+ tsl::Status cc_status = cc_agent->DeleteKeyValue(key);
+ tsl::Set_TF_Status_from_Status(status, cc_status);
+}
+
+// ---------------------------- PJRT -----------------------------------------
+void TF_CreateAndSetPjRtCApiClient(const char* device_type, TF_Status* status) {
+ tsl::StatusOr> pjrt_client =
+ xla::GetCApiClient(device_type);
+ if (!pjrt_client.ok()) {
+ tensorflow::Set_TF_Status_from_Status(status, pjrt_client.status());
+ return;
+ }
+
+ tsl::Status s = tensorflow::SetPjRtClientInTFGlobalResourceManager(
+ tensorflow::DeviceType(device_type), std::move(*pjrt_client));
+ tsl::Set_TF_Status_from_Status(status, s);
+}
+
+PJRT_Client* TF_GetPjRtCClient(const char* device_type, TF_Status* status) {
+ tsl::StatusOr pjrt_client =
+ tensorflow::GetOrCreatePjRtClient(tensorflow::DeviceType(device_type));
+ if (!pjrt_client.ok()) {
+ tensorflow::Set_TF_Status_from_Status(status, pjrt_client.status());
+ return nullptr;
+ }
+ auto* pjrt_c_api_client =
+ tensorflow::down_cast(*pjrt_client);
+ if (pjrt_c_api_client == nullptr) {
+ tensorflow::Set_TF_Status_from_Status(
+ status, tsl::errors::Internal("PjRtClient for ", device_type,
+ " is not type PjRtCApiClient"));
+ return nullptr;
+ }
+ TF_SetStatus(status, TF_OK, "");
+ return pjrt_c_api_client->pjrt_c_client();
+}
+
+PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) {
+ tensorflow::Tensor tensor;
+ auto s = tensorflow::TF_TensorToTensor(c_tensor, &tensor);
+ if (!s.ok()) {
+ tensorflow::Set_TF_Status_from_Status(status, s);
+ return nullptr;
+ }
+ tensorflow::AsyncValueTensor* av_tensor =
+ tensorflow::AsyncValueTensor::FromTensor(&tensor);
+ if (av_tensor == nullptr || av_tensor->GetBuffer() == nullptr) {
+ tensorflow::Set_TF_Status_from_Status(
+ status,
+ tsl::errors::Internal("Input tensor does not have PjRtBuffer."));
+ return nullptr;
+ }
+ auto* c_api_buffer =
+ tensorflow::down_cast(av_tensor->GetBuffer().get());
+ if (c_api_buffer == nullptr) {
+ tensorflow::Set_TF_Status_from_Status(
+ status,
+ tsl::errors::Internal(
+ "The PjRtBuffer in the tensor is not type PjRtCApiBuffer."));
+ return nullptr;
+ }
+ TF_SetStatus(status, TF_OK, "");
+ return c_api_buffer->c_buffer();
+}
+
+void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer,
+ const char* device_type, TF_Status* status) {
+ tensorflow::Tensor tensor;
+ auto s = tensorflow::TF_TensorToTensor(c_tensor, &tensor);
+ if (!s.ok()) {
+ tensorflow::Set_TF_Status_from_Status(status, s);
+ return;
+ }
+ auto pjrt_client =
+ tensorflow::GetOrCreatePjRtClient(tensorflow::DeviceType(device_type));
+ if (!pjrt_client.ok()) {
+ tensorflow::Set_TF_Status_from_Status(status, pjrt_client.status());
+ return;
+ }
+ auto* pjrt_c_api_client =
+ tensorflow::down_cast(*pjrt_client);
+ if (pjrt_c_api_client == nullptr) {
+ tensorflow::Set_TF_Status_from_Status(
+ status, tsl::errors::Internal("PjRtClient for ", device_type,
+ " is not type PjRtCApiClient"));
+ return;
+ }
+ tensorflow::AsyncValueTensor* av_tensor =
+ tensorflow::AsyncValueTensor::FromTensor(&tensor);
+ av_tensor->SetBuffer(
+ std::make_unique(pjrt_c_api_client, c_buffer));
+ TF_SetStatus(status, TF_OK, "");
+}
diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.h b/tensorflow/c/experimental/next_pluggable_device/c_api.h
new file mode 100644
index 00000000000..e577f02a595
--- /dev/null
+++ b/tensorflow/c/experimental/next_pluggable_device/c_api.h
@@ -0,0 +1,153 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_
+#define TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/kernels.h"
+#include "tensorflow/c/kernels_experimental.h"
+#include "tensorflow/c/tf_buffer.h"
+#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
+
+// --------------------------------------------------------------------------
+// C API for device. The API is under active development and eventually
+// should allow registering a plugin device with TensorFlow.
+
+// Macro to control visibility of exported symbols in the shared library (.so,
+// .dylib, .dll).
+// This duplicates the TF_EXPORT macro definition in
+// tensorflow/core/platform/macros.h in order to keep this .h file independent
+// of any other includes.
+#ifdef SWIG
+#define TF_CAPI_EXPORT
+#else
+#if defined(_WIN32)
+#ifdef TF_COMPILE_LIBRARY
+#define TF_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TF_CAPI_EXPORT __declspec(dllimport)
+#endif // TF_COMPILE_LIBRARY
+#else
+#define TF_CAPI_EXPORT __attribute__((visibility("default")))
+#endif // _WIN32
+#endif // SWIG
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// TF_Device is a C wrapper to the C++ TF Device class. This is to be passed
+// through TF_OpKernelContext, and is opaque to plugin.
+typedef struct TF_Device TF_Device;
+
+typedef struct TF_VariableInfo TF_VariableInfo;
+
+// Returns a `TF_Device` pointer, which actually points to a C++ `Device`.
+// Currently we only allow `NextPluggableDevice` to be casted as `TF_Device`,
+// but in theory this could be a C API for every kind of device.
+TF_CAPI_EXPORT extern TF_Device* TF_GetDevice(TF_OpKernelContext* ctx);
+
+TF_CAPI_EXPORT extern size_t TF_GetDeviceOrdinal(TF_Device* device);
+
+// -------------------------- Resource ---------------------------------------
+// Create a `tensorflow::PluginResource` to the ResourceMgr provided by the
+// `ctx`. The `tensorflow::PluginResource` wraps a resource by plugin (as a
+// opaque pointer, since TensorFlow cannot parse it). `delete_func` is needed
+// for ResourceMgr to clean up the resource. `status` will be set.
+TF_CAPI_EXPORT extern void TF_CreatePluginResource(
+ TF_OpKernelContext* ctx, const char* container_name,
+ const char* plugin_resource_name, void* plugin_resource,
+ void (*delete_func)(void*), TF_Status* status);
+
+// If the ResourceMgr provided by the `ctx` has a resource
+// `plugin_resource_name`, returns it in `*result_plugin_resource`. Otherwise,
+// invokes create_func to create the resource. `delete_func` is needed for
+// ResourceMgr to clean up the resource. `status` will be set. If `status` is
+// not OK, `*result_plugin_resource` will be set as nullptr.
+//
+// Caller does not take ownership of the `plugin_resource`.
+TF_CAPI_EXPORT extern void TF_LookupOrCreatePluginResource(
+ TF_OpKernelContext* ctx, const char* container_name,
+ const char* plugin_resource_name, void** result_plugin_resource,
+ void* (*create_func)(void*), void* create_func_args,
+ void (*delete_func)(void*), TF_Status* status);
+
+// ------------------------- VariableInfo ------------------------------------
+TF_CAPI_EXPORT extern TF_VariableInfo* TF_CreateVariableInfoFromContext(
+ TF_OpKernelContext* ctx, int index, TF_Status* status);
+
+TF_CAPI_EXPORT extern void TF_LockVariableInfos(TF_VariableInfo** vars,
+ int num_vars,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern void TF_AllocateTempForVariableInfo(
+ TF_OpKernelContext* ctx, TF_VariableInfo* var_info, TF_Status* status);
+
+TF_CAPI_EXPORT extern TF_Tensor* TF_GetTensorFromVariableInfo(
+ TF_VariableInfo* var_info, TF_Status* status);
+
+TF_CAPI_EXPORT extern void TF_DeleteVariableInfo(TF_VariableInfo* var_info);
+
+// --------------------- Coordination service --------------------------------
+// Returns a non-owning pointer to the coordination service agent, which is
+// opaque to plugin. Plugin OpKernels need to use the accompanying C APIs to
+// access coordination service functionalities.
+TF_CAPI_EXPORT extern TF_CoordinationServiceAgent*
+TF_GetCoordinationServiceAgent(TF_OpKernelContext* ctx);
+
+// Returns true if the coordination service agent has been initialized.
+TF_CAPI_EXPORT extern bool TF_CoordinationServiceIsInitialized(
+ TF_CoordinationServiceAgent* agent);
+
+TF_CAPI_EXPORT extern void TF_CoordinationServiceInsertKeyValue(
+ const char* key, const char* value, TF_CoordinationServiceAgent* agent,
+ TF_Status* status);
+
+// Obtains key-value from coordination service agent. The returned `TF_Buffer`
+// is a newly allocated buffer to hold the string key-value, and caller is
+// responsible for managing the lifetime. If error, `status` will be set and a
+// nullptr will be returned.
+TF_CAPI_EXPORT extern TF_Buffer* TF_CoordinationServiceGetKeyValue(
+ const char* key, TF_CoordinationServiceAgent* agent, TF_Status* status);
+
+TF_CAPI_EXPORT extern void TF_CoordinationServiceDeleteKeyValue(
+ const char* key, TF_CoordinationServiceAgent* agent, TF_Status* status);
+
+// ---------------------------- PJRT -----------------------------------------
+TF_CAPI_EXPORT extern void TF_CreateAndSetPjRtCApiClient(
+ const char* device_type, TF_Status* status);
+
+// Gets the `PJRT_Client*` stored in TF global ResourceManager.
+TF_CAPI_EXPORT extern PJRT_Client* TF_GetPjRtCClient(const char* device_type,
+ TF_Status* status);
+
+// Gets the `PJRT_Buffer*` stored in the tensor. The status will contain error
+// if the tensor does not have a `PjRtCApiBuffer`.
+TF_CAPI_EXPORT extern PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor,
+ TF_Status* status);
+
+// Creates a `PjRtCApiBuffer` with the `PJRT_Buffer*` passed in and set to the
+// tensor.
+TF_CAPI_EXPORT extern void TF_CreatePjRtBuffer(TF_Tensor* c_tensor,
+ PJRT_Buffer* c_buffer,
+ const char* device_type,
+ TF_Status* status);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_
diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD
index e5cf1c39f65..13f1c808d45 100644
--- a/tensorflow/c/experimental/ops/BUILD
+++ b/tensorflow/c/experimental/ops/BUILD
@@ -3,6 +3,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# Experimental ops. These will eventually be replaced by machine-generated versions.
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/BUILD b/tensorflow/c/experimental/ops/gen/BUILD
index 21e855dceb9..7ab0a9f49c5 100644
--- a/tensorflow/c/experimental/ops/gen/BUILD
+++ b/tensorflow/c/experimental/ops/gen/BUILD
@@ -4,6 +4,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/common/BUILD b/tensorflow/c/experimental/ops/gen/common/BUILD
index 2dcbc644cf0..a5618623bbd 100644
--- a/tensorflow/c/experimental/ops/gen/common/BUILD
+++ b/tensorflow/c/experimental/ops/gen/common/BUILD
@@ -4,6 +4,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/cpp/BUILD b/tensorflow/c/experimental/ops/gen/cpp/BUILD
index 7b9aa347198..d2fd0294adb 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/BUILD
+++ b/tensorflow/c/experimental/ops/gen/cpp/BUILD
@@ -4,6 +4,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD b/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD
index 5180b86cece..86880db388b 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD
+++ b/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD
@@ -1,4 +1,5 @@
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD
index 2d41ae84512..7589ea2d2f2 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD
+++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD
@@ -4,6 +4,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD
index 455c6cac143..46f61c89d8e 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD
+++ b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD
@@ -1,4 +1,5 @@
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/ops/gen/model/BUILD b/tensorflow/c/experimental/ops/gen/model/BUILD
index 04df5d61748..918acaabb6b 100644
--- a/tensorflow/c/experimental/ops/gen/model/BUILD
+++ b/tensorflow/c/experimental/ops/gen/model/BUILD
@@ -1,4 +1,5 @@
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//tensorflow/c/experimental/ops/gen:__subpackages__"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/pluggable_profiler/BUILD b/tensorflow/c/experimental/pluggable_profiler/BUILD
index 9fd79348de6..4e3de6a46c1 100644
--- a/tensorflow/c/experimental/pluggable_profiler/BUILD
+++ b/tensorflow/c/experimental/pluggable_profiler/BUILD
@@ -5,6 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.default.bzl", "filegroup")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
@@ -61,8 +62,8 @@ cc_library(
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/platform:status",
- "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/lib:profiler_interface",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "//tensorflow/tsl/profiler/protobuf:profiler_options_proto_cc",
],
)
diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc
index 6e8cc32e556..0efa257723b 100644
--- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc
+++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_factory.h"
#include "tensorflow/core/profiler/lib/profiler_interface.h"
-#include "tensorflow/core/profiler/profiler_options.pb.h"
+#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h
index 103c0905f08..6dbbe4549ff 100644
--- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h
+++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h
@@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_interface.h"
-#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD
index 394c7de8b59..d72cf86a7bc 100644
--- a/tensorflow/c/experimental/saved_model/core/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/BUILD
@@ -11,7 +11,9 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
+ # copybara:uncomment() "//learning/brain/tfrt/aot:__pkg__",
"//tensorflow/c:__subpackages__",
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
"//tensorflow/cc/experimental/libtf:__pkg__",
diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD
index 14fa051a4ab..cce725db3fc 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD
@@ -8,6 +8,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__subpackages__",
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
index 3c2050e79ec..ab7de9bae06 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
@@ -3,9 +3,11 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__pkg__",
+ # copybara:uncomment "//learning/brain/tfrt/aot:__pkg__",
],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
index 2a4297e2b67..660a417be8f 100644
--- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
+++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
@@ -80,51 +80,6 @@ Status ConstantFromSavedConstant(
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
-// Finds the "signatures" object in the object graph, and fills a mapping of
-// each signature's name to the corresponding function's node in the object
-// graph.
-Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
- gtl::FlatMap<std::string, int>* signatures_map) {
- if (saved_objects.nodes().empty()) {
- return errors::FailedPrecondition("Saved Object Graph was empty.");
- }
- const SavedObject& root = saved_objects.nodes(0);
- const SavedObject* signatures = nullptr;
- for (const auto& child : root.children()) {
- if (child.local_name() == "signatures") {
- if (child.node_id() >= saved_objects.nodes().size()) {
- return errors::FailedPrecondition(
- "Signature object had child node id ", child.node_id(),
- " which exceeds the size of the set of nodes");
- }
- signatures = &saved_objects.nodes(child.node_id());
- }
- }
-
- // Some basic sanity checks that this object is actually our "signatures" map
- if (signatures == nullptr) {
- // This is where the "signatures" attribute is always set:
- // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
- return errors::FailedPrecondition(
- "SavedObjectGraph's root object must have a child 'signatures' object");
- }
- if (signatures->kind_case() != SavedObject::kUserObject) {
- return errors::FailedPrecondition(
- "Signatures must be a SavedObject of type UserObject.");
- }
- if (signatures->user_object().identifier() != "signature_map") {
- // This is where the string comes from:
- // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
- return errors::FailedPrecondition(
- "Signatures SavedObject must have identifier 'signature_map'.");
- }
-
- for (const auto& child : signatures->children()) {
- (*signatures_map)[child.local_name()] = child.node_id();
- }
- return Status();
-}
-
// Perform some basic sanity checks on SavedConcreteFunction's input and
// output signatures with respect to the corresponding FunctionDef's input
// and output args.
@@ -183,6 +138,50 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
return Status();
}
+} // namespace
+
+Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
+ gtl::FlatMap<std::string, int>* signatures_map) {
+ if (saved_objects.nodes().empty()) {
+ return errors::FailedPrecondition("Saved Object Graph was empty.");
+ }
+ const SavedObject& root = saved_objects.nodes(0);
+ const SavedObject* signatures = nullptr;
+ for (const auto& child : root.children()) {
+ if (child.local_name() == "signatures") {
+ if (child.node_id() >= saved_objects.nodes().size()) {
+ return errors::FailedPrecondition(
+ "Signature object had child node id ", child.node_id(),
+ " which exceeds the size of the set of nodes");
+ }
+ signatures = &saved_objects.nodes(child.node_id());
+ }
+ }
+
+ // Some basic sanity checks that this object is actually our "signatures" map
+ if (signatures == nullptr) {
+ // This is where the "signatures" attribute is always set:
+ // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
+ return errors::FailedPrecondition(
+ "SavedObjectGraph's root object must have a child 'signatures' object");
+ }
+ if (signatures->kind_case() != SavedObject::kUserObject) {
+ return errors::FailedPrecondition(
+ "Signatures must be a SavedObject of type UserObject.");
+ }
+ if (signatures->user_object().identifier() != "signature_map") {
+ // This is where the string comes from:
+ // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
+ return errors::FailedPrecondition(
+ "Signatures SavedObject must have identifier 'signature_map'.");
+ }
+
+ for (const auto& child : signatures->children()) {
+ (*signatures_map)[child.local_name()] = child.node_id();
+ }
+ return Status();
+}
+
Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
// We only allow loading functions that have an annotated input signature,
// which means there is 1:1 correspondence between tf.function
@@ -198,8 +197,6 @@ Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
return Status();
}
-} // namespace
-
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
const std::string& saved_model_dir,
absl::Span assets,
@@ -438,9 +435,11 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
resource_revival_state.device = node.resource().device();
objects->restored_resources[i] = std::move(resource_revival_state);
} else if (node.kind_case() == SavedObject::kFunction) {
- // Get the SavedFunction node and validate it has a single concrete func.
+ // Get the SavedFunction node and skip if it has no concrete functions.
const SavedFunction& saved_function = node.function();
- TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function));
+ if (saved_function.concrete_functions_size() < 1) {
+ continue;
+ }
// Retrieve related function information.
const std::string& function_name = saved_function.concrete_functions(0);
diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h
index db45e28087f..34b4499621c 100644
--- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h
+++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h
@@ -94,6 +94,15 @@ gtl::FlatMap NodeToAttrMap(
gtl::FlatMap<StringPiece, const FunctionDef*>
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library);
+// Finds the "signatures" object in the object graph, and fills a mapping of
+// each signature's name to the corresponding function's node in the object
+// graph.
+Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
+ gtl::FlatMap<std::string, int>* signatures_map);
+
+// Validates the `saved_function`.
+Status ValidateSingleConcreteFunction(const SavedFunction& saved_function);
+
// Walks through the SavedObjectGraph in metagraph, and restores all nodes
// (except "UserDefinedObjects") with their corresponding type in
// "PartiallyRevivedObjects".
diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD
index 2647a822f93..d6dc1f202b0 100644
--- a/tensorflow/c/experimental/saved_model/internal/BUILD
+++ b/tensorflow/c/experimental/saved_model/internal/BUILD
@@ -20,7 +20,10 @@ load(
"tf_copts",
)
-package(licenses = ["notice"])
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ licenses = ["notice"],
+)
cc_library(
name = "concrete_function",
diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD
index 49acc9274fc..ab1a6e3689e 100644
--- a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD
+++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD
@@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup")
load("//tensorflow:strict.default.bzl", "py_strict_binary")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD
index 71fd46ab889..6a711ae1738 100644
--- a/tensorflow/c/experimental/saved_model/public/BUILD
+++ b/tensorflow/c/experimental/saved_model/public/BUILD
@@ -11,6 +11,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
# This is intentionally public
default_visibility = [
"//visibility:public",
diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD
index 849c0f2c22b..d06c536f671 100644
--- a/tensorflow/c/experimental/stream_executor/BUILD
+++ b/tensorflow/c/experimental/stream_executor/BUILD
@@ -9,6 +9,7 @@ load(
load("//tensorflow:tensorflow.default.bzl", "filegroup")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
@@ -43,13 +44,12 @@ cc_library(
"//tensorflow/compiler/xla/stream_executor:executor_cache",
"//tensorflow/compiler/xla/stream_executor:multi_platform_manager",
"//tensorflow/compiler/xla/stream_executor:platform",
- "//tensorflow/compiler/xla/stream_executor:stream_executor_internal",
"//tensorflow/compiler/xla/stream_executor:stream_executor_pimpl",
"//tensorflow/compiler/xla/stream_executor:timer",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/device:device_utils",
- "//tensorflow/core/platform:regexp",
"//tensorflow/core/platform:strcat",
+ "@com_google_absl//absl/functional:any_invocable",
],
)
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc
index c8a9670156b..2ba7d3cc953 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include <string>
+#include <utility>
+#include "absl/functional/any_invocable.h"
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/c_api_macros_internal.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
@@ -51,7 +53,7 @@ using tensorflow::StringPiece;
using OwnedTFStatus = tensorflow::TF_StatusPtr;
namespace {
-port::Status ValidateSPPlatform(const SP_Platform& platform) {
+tsl::Status ValidateSPPlatform(const SP_Platform& platform) {
TF_VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
TF_VALIDATE_NOT_NULL(SP_Platform, platform, name);
TF_VALIDATE_NOT_NULL(SP_Platform, platform, type);
@@ -63,7 +65,7 @@ port::Status ValidateSPPlatform(const SP_Platform& platform) {
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
+tsl::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
TF_VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
SP_PLATFORM_FNS_STRUCT_SIZE);
TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device);
@@ -77,40 +79,40 @@ port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
+tsl::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
TF_VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
TF_VALIDATE_NOT_NULL(SP_TimerFns, timer_fns, nanoseconds);
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
+tsl::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
TF_VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats,
SP_ALLOCATORSTATS_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
+tsl::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
TF_VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPDevice(const SP_Device& device) {
+tsl::Status ValidateSPDevice(const SP_Device& device) {
TF_VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
+tsl::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
TF_VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return ::tensorflow::OkStatus();
}
-port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
- const SP_Platform& platform) {
+tsl::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
+ const SP_Platform& platform) {
TF_VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se,
SP_STREAM_EXECUTOR_STRUCT_SIZE);
TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, allocate);
@@ -149,7 +151,7 @@ port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
return ::tensorflow::OkStatus();
}
-port::Status ValidateSEPlatformRegistrationParams(
+tsl::Status ValidateSEPlatformRegistrationParams(
const SE_PlatformRegistrationParams& params) {
TF_VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
@@ -193,7 +195,7 @@ DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
// Wrapper that allows passing std::function across C API.
struct HostCallbackContext {
- std::function<port::Status()> callback;
+ absl::AnyInvocable<tsl::Status() &&> callback;
};
// This wrapper allows calling `HostCallbackContext::callback` across C API.
@@ -201,7 +203,7 @@ struct HostCallbackContext {
// `callback_fn` to `host_callback` in `SP_StreamExecutor`.
void HostCallbackTrampoline(void* ctx, TF_Status* status) {
HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
- port::Status s = host_ctx->callback();
+ tsl::Status s = std::move(host_ctx->callback)();
Set_TF_Status_from_Status(status, s);
delete host_ctx;
}
@@ -226,14 +228,14 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
platform_fns_->destroy_device(platform_, &device_);
}
- port::Status Init(int device_ordinal, DeviceOptions device_options) override {
+ tsl::Status Init(int device_ordinal, DeviceOptions device_options) override {
return ::tensorflow::OkStatus();
}
DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override {
SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
stream_executor_->allocate(&device_, size, memory_space, &mem);
- port::Status status = ValidateSPDeviceMemoryBase(mem);
+ tsl::Status status = ValidateSPDeviceMemoryBase(mem);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
@@ -280,7 +282,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
if (!has_stats) {
return absl::nullopt;
}
- port::Status status = ValidateSPAllocatorStats(c_stats);
+ tsl::Status status = ValidateSPAllocatorStats(c_stats);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
return absl::nullopt;
@@ -310,38 +312,37 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
}
return true;
}
- port::Status SynchronousMemZero(DeviceMemoryBase* location,
- uint64 size) override {
+ tsl::Status SynchronousMemZero(DeviceMemoryBase* location,
+ uint64 size) override {
// TODO(annarev): figure out if we should support memzero/memset
// functionality by allocating on host and then copying to device.
- return port::UnimplementedError(
+ return tsl::errors::Unimplemented(
"SynchronousMemZero is not supported by pluggable device.");
}
- port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
- uint64 size) override {
- return port::UnimplementedError(
+ tsl::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
+ uint64 size) override {
+ return tsl::errors::Unimplemented(
"SynchronousMemSet is not supported by pluggable device.");
}
- port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
- const void* host_src, uint64 size) override {
+ tsl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src,
+ uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
- port::Status SynchronousMemcpy(void* host_dst,
- const DeviceMemoryBase& gpu_src,
- uint64 size) override {
+ tsl::Status SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src,
+ uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
- port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
- const DeviceMemoryBase& gpu_src,
- uint64 size) override {
+ tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
+ const DeviceMemoryBase& gpu_src,
+ uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
@@ -349,8 +350,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
&device_mem_src, size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
- port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
- uint64 size) override {
+ tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location,
+ uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
@@ -359,8 +360,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
c_status.get());
return StatusFromTF_Status(c_status.get());
}
- port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
- uint64 size) override {
+ tsl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
+ uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
@@ -369,8 +370,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
- port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
- uint32 pattern, uint64 size) override {
+ tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location,
+ uint32 pattern, uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
@@ -424,27 +425,27 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
return true;
}
bool HostCallback(Stream* stream,
- std::function<port::Status()> callback) override {
+ absl::AnyInvocable<tsl::Status() &&> callback) override {
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
- HostCallbackContext* ctx = new HostCallbackContext{callback};
+ HostCallbackContext* ctx = new HostCallbackContext{std::move(callback)};
return stream_executor_->host_callback(&device_, stream_handle,
&HostCallbackTrampoline, ctx);
}
- port::Status AllocateEvent(Event* event) override {
+ tsl::Status AllocateEvent(Event* event) override {
DCHECK(event != nullptr);
return static_cast(event->implementation())->Create();
}
- port::Status DeallocateEvent(Event* event) override {
+ tsl::Status DeallocateEvent(Event* event) override {
static_cast(event->implementation())->Destroy();
return ::tensorflow::OkStatus();
}
- port::Status RecordEvent(Stream* stream, Event* event) override {
+ tsl::Status RecordEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
return static_cast(event->implementation())->Record(stream_handle);
}
- port::Status WaitForEvent(Stream* stream, Event* event) override {
+ tsl::Status WaitForEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
SP_Event event_handle =
@@ -452,7 +453,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
c_status.get());
- port::Status s = StatusFromTF_Status(c_status.get());
+ tsl::Status s = StatusFromTF_Status(c_status.get());
return s;
}
Event::Status PollForEventStatus(Event* event) override {
@@ -464,7 +465,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
}
bool AllocateStream(Stream* stream) override {
DCHECK(stream != nullptr);
- port::Status status =
+ tsl::Status status =
static_cast(stream->implementation())->Create();
// TODO(annarev): update AllocateStream to return status instead
// (similar to AllocateEvent).
@@ -488,7 +489,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
return true;
}
bool AllocateTimer(Timer* timer) override {
- port::Status status =
+ tsl::Status status =
static_cast(timer->implementation())->Create();
// TODO(annarev): change return value of AllocateTimer
// to status (similar to AllocateEvent).
@@ -525,7 +526,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
}
return true;
}
- port::Status BlockHostForEvent(Stream* stream, Event* event) {
+ tsl::Status BlockHostForEvent(Stream* stream, Event* event) {
OwnedTFStatus c_status(TF_NewStatus());
SP_Event event_handle =
static_cast(event->implementation())->Handle();
@@ -534,7 +535,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
return StatusFromTF_Status(c_status.get());
}
- port::Status BlockHostUntilDone(Stream* stream) override {
+ tsl::Status BlockHostUntilDone(Stream* stream) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
@@ -551,7 +552,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
stream_executor_->record_event(&device_, stream_handle, event_handle,
c_status.get());
- port::Status s = StatusFromTF_Status(c_status.get());
+ tsl::Status s = StatusFromTF_Status(c_status.get());
if (!s.ok()) {
stream_executor_->destroy_event(&device_, event_handle);
return s;
@@ -562,7 +563,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
return StatusFromTF_Status(c_status.get());
}
- port::Status GetStatus(Stream* stream) override {
+ tsl::Status GetStatus(Stream* stream) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast(stream->implementation())->Handle();
@@ -571,8 +572,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
return StatusFromTF_Status(c_status.get());
}
int PlatformDeviceCount() override { return visible_device_count_; }
- port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
- return port::UnimplementedError(
+ tsl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
+ return tsl::errors::Unimplemented(
"EnablePeerAccessTo is not supported by pluggable device.");
}
bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
@@ -587,7 +588,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
// Creates a new DeviceDescription object.
// Ownership is transferred to the caller.
- port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
+ tsl::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
const override {
OwnedTFStatus c_status(TF_NewStatus());
@@ -679,7 +680,7 @@ CPlatform::~CPlatform() {
destroy_platform_fns_(&platform_fns_);
}
-port::StatusOr<std::unique_ptr<DeviceDescription>>
+tsl::StatusOr<std::unique_ptr<DeviceDescription>>
CPlatform::DescriptionForDevice(int ordinal) const {
// TODO(annarev): see if we can get StreamExecutor instance
// and call GetDeviceDescription. executor_cache_.Get would need
@@ -688,24 +689,24 @@ CPlatform::DescriptionForDevice(int ordinal) const {
builder.set_name(name_);
return builder.Build();
}
-port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
+tsl::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
stream_executor::StreamExecutorConfig config;
config.ordinal = ordinal;
return GetExecutor(config);
}
-port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
+tsl::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
int ordinal, const PluginConfig& plugin_config) {
StreamExecutorConfig config;
config.ordinal = ordinal;
config.plugin_config = plugin_config;
return GetExecutor(config);
}
-port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
+tsl::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
-port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
+tsl::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
// Fill device creation params
SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
@@ -734,9 +735,8 @@ port::StatusOr> CPlatform::GetUncachedExecutor(
return result;
}
-port::Status InitStreamExecutorPlugin(void* dso_handle,
- std::string* device_type,
- std::string* platform_name) {
+tsl::Status InitStreamExecutorPlugin(void* dso_handle, std::string* device_type,
+ std::string* platform_name) {
tensorflow::Env* env = tensorflow::Env::Default();
// Step 1: Load symbol for `TF_InitPlugin`
@@ -749,9 +749,9 @@ port::Status InitStreamExecutorPlugin(void* dso_handle,
return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
}
-port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
- std::string* device_type,
- std::string* platform_name) {
+tsl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
+ std::string* device_type,
+ std::string* platform_name) {
SE_PlatformRegistrationParams params{
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
@@ -804,7 +804,7 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
std::move(platform), params.destroy_platform, std::move(platform_fns),
params.destroy_platform_fns, std::move(device_fns), std::move(se),
std::move(timer_fns)));
- SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
+ TF_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));
// TODO(annarev): Return `use_bfc_allocator` value in some way so that it is
// available in `PluggableDeviceProcessState` once the latter is checked in.
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
index 7246dde2660..ad8a77d61fa 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/executor_cache.h"
-#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
namespace stream_executor {
@@ -33,15 +32,14 @@ typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
// Registers StreamExecutor platform. `device_type` and `platform_name` are
// output parameters.
-port::Status InitStreamExecutorPlugin(void* dso_handle,
- std::string* device_type,
- std::string* platform_name);
+tsl::Status InitStreamExecutorPlugin(void* dso_handle, std::string* device_type,
+ std::string* platform_name);
// Allow registering a StreamExecutor plugin using a function (used for
// testing).
-port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
- std::string* device_type,
- std::string* platform_name);
+tsl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
+ std::string* device_type,
+ std::string* platform_name);
// This file implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
@@ -71,14 +69,14 @@ class CPlatform : public Platform {
}
bool UseBfcAllocator() const { return platform_.use_bfc_allocator; }
bool ForceMemoryGrowth() const { return platform_.force_memory_growth; }
- port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
+ tsl::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
int ordinal) const override;
- port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
- port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+ tsl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+ tsl::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
int ordinal, const PluginConfig& plugin_config) override;
- port::StatusOr<StreamExecutor*> GetExecutor(
+ tsl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;
- port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ tsl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
const StreamExecutorConfig& config) override;
// Trace listener is not supported
@@ -110,10 +108,10 @@ class CStream : public internal::StreamInterface {
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
- port::Status Create() {
+ tsl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
- port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
+ tsl::Status s = tensorflow::StatusFromTF_Status(c_status.get());
return s;
}
@@ -140,13 +138,13 @@ class CEvent : public internal::EventInterface {
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
- port::Status Create() {
+ tsl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
- port::Status Record(SP_Stream stream_handle) {
+ tsl::Status Record(SP_Stream stream_handle) {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
@@ -178,7 +176,7 @@ class CTimer : public internal::TimerInterface {
timer_fns_(timer_fns) {}
~CTimer() override { Destroy(); }
- port::Status Create() {
+ tsl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
index 8b82121c51d..cf21374c48f 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
@@ -38,17 +38,17 @@ TEST(StreamExecutor, SuccessfulRegistration) {
test_util::PopulateDefaultPlatformRegistrationParams(params);
};
std::string device_type, platform_name;
- port::Status status =
+ tsl::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
TF_ASSERT_OK(status);
- port::StatusOr maybe_platform =
+ tsl::StatusOr maybe_platform =
MultiPlatformManager::PlatformWithName("MY_DEVICE");
TF_ASSERT_OK(maybe_platform.status());
Platform* platform = std::move(maybe_platform).value();
ASSERT_EQ(platform->Name(), test_util::kDeviceName);
ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount);
- port::StatusOr maybe_executor =
+ tsl::StatusOr maybe_executor =
platform->ExecutorForDevice(0);
TF_ASSERT_OK(maybe_executor.status());
}
@@ -62,7 +62,7 @@ TEST(StreamExecutor, NameNotSet) {
};
std::string device_type, platform_name;
- port::Status status =
+ tsl::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
@@ -77,7 +77,7 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) {
};
std::string device_type, platform_name;
- port::Status status =
+ tsl::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(
@@ -94,7 +94,7 @@ TEST(StreamExecutor, InvalidNameWithSlash) {
};
std::string device_type, platform_name;
- port::Status status =
+ tsl::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(status.error_message(),
@@ -110,7 +110,7 @@ TEST(StreamExecutor, CreateDeviceNotSet) {
};
std::string device_type, platform_name;
- port::Status status =
+ tsl::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(),
@@ -126,7 +126,7 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) {
};
std::string device_type, platform_name;
- port::Status status =
+ tsl::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(
@@ -152,7 +152,7 @@ class StreamExecutorTest : public ::testing::Test {
platform_, test_util::DestroyPlatform, platform_fns_,
test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_);
}
- port::StatusOr maybe_executor =
+ tsl::StatusOr maybe_executor =
cplatform_->ExecutorForDevice(ordinal);
TF_CHECK_OK(maybe_executor.status());
return std::move(maybe_executor).value();
@@ -724,7 +724,7 @@ TEST_F(StreamExecutorTest, HostCallbackOk) {
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
- std::function callback = []() -> port::Status {
+ std::function callback = []() -> tsl::Status {
return ::tensorflow::OkStatus();
};
stream.ThenDoHostCallbackWithStatus(callback);
@@ -744,8 +744,8 @@ TEST_F(StreamExecutorTest, HostCallbackError) {
StreamExecutor* executor = GetExecutor(0);
Stream stream(executor);
stream.Init();
- std::function callback = []() -> port::Status {
- return port::UnimplementedError("Unimplemented");
+ std::function callback = []() -> tsl::Status {
+ return tsl::errors::Unimplemented("Unimplemented");
};
stream.ThenDoHostCallbackWithStatus(callback);
ASSERT_FALSE(stream.ok());
diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD
index e3795a2715b..2a4d40b3e79 100644
--- a/tensorflow/c/experimental/stream_executor/test/BUILD
+++ b/tensorflow/c/experimental/stream_executor/test/BUILD
@@ -6,6 +6,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index c3a54a46b3c..85b2433ac43 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
#include
+#include
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_api_macros.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
@@ -295,6 +297,13 @@ void TF_InputRange(TF_OpKernelContext* ctx, const char* name,
tensorflow::Set_TF_Status_from_Status(args->status, status);
}
+TF_DataType TF_InputDatatype(TF_OpKernelContext* ctx, int index) {
+ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+ CHECK_GE(index, 0); // Crash OK
+ CHECK_LT(index, cc_ctx->num_inputs()); // Crash OK
+ return static_cast(cc_ctx->input_dtype(index));
+}
+
void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
@@ -354,6 +363,18 @@ void TF_GetSerializedConfigProto(TF_OpKernelContext* ctx,
tensorflow::Set_TF_Status_from_Status(status, cc_status);
}
+void TF_GetSerializedResourceHandleProto(
+ TF_OpKernelContext* ctx, int i, TF_Buffer* serialized_resource_handle_proto,
+ TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
+ const tensorflow::ResourceHandle& handle = HandleFromInput(cc_ctx, i);
+ tensorflow::ResourceHandleProto handle_proto;
+ handle.AsProto(&handle_proto);
+ auto cc_status = tensorflow::MessageToBuffer(
+ handle_proto, serialized_resource_handle_proto);
+ tensorflow::Set_TF_Status_from_Status(status, cc_status);
+}
+
void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
@@ -652,6 +673,18 @@ int64_t TF_GetIterId(TF_OpKernelContext* ctx) {
.iter_id;
}
+int64_t TF_GetStepId(TF_OpKernelContext* ctx) {
+ return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
+}
+
+int TF_GetDeviceId(TF_OpKernelContext* ctx) {
+ // TensorFlow always sets device in OpKernelContext.
+ auto* device =
+ reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->device();
+ if (!device->parsed_name().has_id) return -1;
+ return device->parsed_name().id;
+}
+
TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) {
auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
TF_StringView opkernel_name_sv;
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
index e85dc9f252a..2e765b7dfaa 100644
--- a/tensorflow/c/kernels.h
+++ b/tensorflow/c/kernels.h
@@ -190,6 +190,11 @@ TF_CAPI_EXPORT extern void TF_InputRange(TF_OpKernelContext* ctx,
const char* name,
TF_InputRange_Args* args);
+// Returns the data type of the index-th input. If index < 0 or index >=
+// TF_NumInputs(ctx), the program aborts.
+TF_CAPI_EXPORT extern TF_DataType TF_InputDatatype(TF_OpKernelContext* ctx,
+ int index);
+
// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but
// TF_OK, ctx is left unmodified.
//
@@ -216,6 +221,11 @@ TF_CAPI_EXPORT extern void TF_GetSerializedConfigProto(
TF_OpKernelContext* ctx, TF_Buffer* serialized_config_proto,
TF_Status* status);
+// Retrieves a serialized ResourceHandleProto. Status will be set.
+TF_CAPI_EXPORT extern void TF_GetSerializedResourceHandleProto(
+ TF_OpKernelContext* ctx, int i, TF_Buffer* serialized_resource_handle_proto,
+ TF_Status* status);
+
// Notifies the given OpKernelConstruction that kernel construction has failed.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure(
TF_OpKernelConstruction* ctx, TF_Status* status);
@@ -253,6 +263,12 @@ TF_CAPI_EXPORT extern uint64_t TF_GetFrameId(TF_OpKernelContext* ctx);
// Returns the Iter ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_GetIterId(TF_OpKernelContext* ctx);
+// Returns the Step ID of the given context.
+TF_CAPI_EXPORT extern int64_t TF_GetStepId(TF_OpKernelContext* ctx);
+
+// Returns the Device ID of the device that the context possesses.
+TF_CAPI_EXPORT extern int TF_GetDeviceId(TF_OpKernelContext* ctx);
+
// Returns the graph def version of the given context.
TF_CAPI_EXPORT extern int TF_GetGraphDefVersion(TF_OpKernelContext* ctx);
diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD
index 99fbcfabab4..93ed9a7880b 100644
--- a/tensorflow/c/kernels/BUILD
+++ b/tensorflow/c/kernels/BUILD
@@ -3,6 +3,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)
diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h
index 3e6121bf989..df0c1fb45b0 100644
--- a/tensorflow/c/tf_datatype.h
+++ b/tensorflow/c/tf_datatype.h
@@ -59,7 +59,7 @@ typedef enum TF_DataType {
TF_QINT8 = 11, // Quantized int8
TF_QUINT8 = 12, // Quantized uint8
TF_QINT32 = 13, // Quantized int32
- TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
+ TF_BFLOAT16 = 14, // Float32 truncated to 16 bits.
TF_QINT16 = 15, // Quantized int16
TF_QUINT16 = 16, // Quantized uint16
TF_UINT16 = 17,
@@ -69,6 +69,9 @@ typedef enum TF_DataType {
TF_VARIANT = 21,
TF_UINT32 = 22,
TF_UINT64 = 23,
+ TF_FLOAT8_E5M2 = 24, // 5 exponent bits, 2 mantissa bits.
+ TF_FLOAT8_E4M3FN = 25, // 4 exponent bits, 3 mantissa bits, finite-only, with
+ // 2 NaNs (0bS1111111).
} TF_DataType;
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
diff --git a/tensorflow/c/tf_status.cc b/tensorflow/c/tf_status.cc
index 2f774fa7977..686e09508ac 100644
--- a/tensorflow/c/tf_status.cc
+++ b/tensorflow/c/tf_status.cc
@@ -16,39 +16,21 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/status.h"
-using ::tensorflow::Status;
-using ::tensorflow::error::Code;
-using ::tensorflow::errors::IOError;
-
-TF_Status* TF_NewStatus() { return new TF_Status; }
-
-void TF_DeleteStatus(TF_Status* s) { delete s; }
+// Trampoline implementation to redirect to TSL. Kept here for backward
+// compatibility only.
+TF_Status* TF_NewStatus() { return TSL_NewStatus(); }
+void TF_DeleteStatus(TF_Status* s) { TSL_DeleteStatus(s); }
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
- if (code == TF_OK) {
- s->status = ::tensorflow::OkStatus();
- return;
- }
- s->status = Status(static_cast(code), tensorflow::StringPiece(msg));
+ TSL_SetStatus(s, TSL_Code(code), msg);
}
-
void TF_SetPayload(TF_Status* s, const char* key, const char* value) {
- s->status.SetPayload(key, value);
+ TSL_SetPayload(s, key, value);
}
-
void TF_SetStatusFromIOError(TF_Status* s, int error_code,
const char* context) {
- // TODO(b/139060984): Handle windows when changing its filesystem
- s->status = IOError(context, error_code);
-}
-
-TF_Code TF_GetCode(const TF_Status* s) {
- return static_cast(s->status.code());
-}
-
-const char* TF_Message(const TF_Status* s) {
- return s->status.error_message().c_str();
+ TSL_SetStatusFromIOError(s, error_code, context);
}
+TF_Code TF_GetCode(const TF_Status* s) { return TF_Code(TSL_GetCode(s)); }
+const char* TF_Message(const TF_Status* s) { return TSL_Message(s); }
diff --git a/tensorflow/c/tf_status.h b/tensorflow/c/tf_status.h
index 4616ee434d9..db1d32bf8e7 100644
--- a/tensorflow/c/tf_status.h
+++ b/tensorflow/c/tf_status.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_STATUS_H_
#define TENSORFLOW_C_TF_STATUS_H_
+#include "tensorflow/tsl/c/tsl_status.h"
+
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
@@ -34,30 +36,29 @@ limitations under the License.
extern "C" {
#endif
-typedef struct TF_Status TF_Status;
+typedef struct TSL_Status TF_Status;
// --------------------------------------------------------------------------
// TF_Code holds an error code. The enum values here are identical to
// corresponding values in error_codes.proto.
-typedef enum TF_Code {
- TF_OK = 0,
- TF_CANCELLED = 1,
- TF_UNKNOWN = 2,
- TF_INVALID_ARGUMENT = 3,
- TF_DEADLINE_EXCEEDED = 4,
- TF_NOT_FOUND = 5,
- TF_ALREADY_EXISTS = 6,
- TF_PERMISSION_DENIED = 7,
- TF_UNAUTHENTICATED = 16,
- TF_RESOURCE_EXHAUSTED = 8,
- TF_FAILED_PRECONDITION = 9,
- TF_ABORTED = 10,
- TF_OUT_OF_RANGE = 11,
- TF_UNIMPLEMENTED = 12,
- TF_INTERNAL = 13,
- TF_UNAVAILABLE = 14,
- TF_DATA_LOSS = 15,
-} TF_Code;
+typedef TSL_Code TF_Code;
+#define TF_OK TSL_OK
+#define TF_CANCELLED TSL_CANCELLED
+#define TF_UNKNOWN TSL_UNKNOWN
+#define TF_INVALID_ARGUMENT TSL_INVALID_ARGUMENT
+#define TF_DEADLINE_EXCEEDED TSL_DEADLINE_EXCEEDED
+#define TF_NOT_FOUND TSL_NOT_FOUND
+#define TF_ALREADY_EXISTS TSL_ALREADY_EXISTS
+#define TF_PERMISSION_DENIED TSL_PERMISSION_DENIED
+#define TF_UNAUTHENTICATED TSL_UNAUTHENTICATED
+#define TF_RESOURCE_EXHAUSTED TSL_RESOURCE_EXHAUSTED
+#define TF_FAILED_PRECONDITION TSL_FAILED_PRECONDITION
+#define TF_ABORTED TSL_ABORTED
+#define TF_OUT_OF_RANGE TSL_OUT_OF_RANGE
+#define TF_UNIMPLEMENTED TSL_UNIMPLEMENTED
+#define TF_INTERNAL TSL_INTERNAL
+#define TF_UNAVAILABLE TSL_UNAVAILABLE
+#define TF_DATA_LOSS TSL_DATA_LOSS
// --------------------------------------------------------------------------
diff --git a/tensorflow/c/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc
index 1e4360d5531..9155d9dde8b 100644
--- a/tensorflow/c/tf_status_helper.cc
+++ b/tensorflow/c/tf_status_helper.cc
@@ -17,75 +17,16 @@ limitations under the License.
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/tsl/c/tsl_status_helper.h"
namespace tsl {
void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) {
- tensorflow::error::Code code = status.code();
- const char* message(status.error_message().c_str());
-
- switch (code) {
- case tensorflow::error::OK:
- assert(TF_GetCode(tf_status) == TF_OK);
- break;
- case tensorflow::error::CANCELLED:
- TF_SetStatus(tf_status, TF_CANCELLED, message);
- break;
- case tensorflow::error::UNKNOWN:
- TF_SetStatus(tf_status, TF_UNKNOWN, message);
- break;
- case tensorflow::error::INVALID_ARGUMENT:
- TF_SetStatus(tf_status, TF_INVALID_ARGUMENT, message);
- break;
- case tensorflow::error::DEADLINE_EXCEEDED:
- TF_SetStatus(tf_status, TF_DEADLINE_EXCEEDED, message);
- break;
- case tensorflow::error::NOT_FOUND:
- TF_SetStatus(tf_status, TF_NOT_FOUND, message);
- break;
- case tensorflow::error::ALREADY_EXISTS:
- TF_SetStatus(tf_status, TF_ALREADY_EXISTS, message);
- break;
- case tensorflow::error::PERMISSION_DENIED:
- TF_SetStatus(tf_status, TF_PERMISSION_DENIED, message);
- break;
- case tensorflow::error::UNAUTHENTICATED:
- TF_SetStatus(tf_status, TF_UNAUTHENTICATED, message);
- break;
- case tensorflow::error::RESOURCE_EXHAUSTED:
- TF_SetStatus(tf_status, TF_RESOURCE_EXHAUSTED, message);
- break;
- case tensorflow::error::FAILED_PRECONDITION:
- TF_SetStatus(tf_status, TF_FAILED_PRECONDITION, message);
- break;
- case tensorflow::error::ABORTED:
- TF_SetStatus(tf_status, TF_ABORTED, message);
- break;
- case tensorflow::error::OUT_OF_RANGE:
- TF_SetStatus(tf_status, TF_OUT_OF_RANGE, message);
- break;
- case tensorflow::error::UNIMPLEMENTED:
- TF_SetStatus(tf_status, TF_UNIMPLEMENTED, message);
- break;
- case tensorflow::error::INTERNAL:
- TF_SetStatus(tf_status, TF_INTERNAL, message);
- break;
- case tensorflow::error::UNAVAILABLE:
- TF_SetStatus(tf_status, TF_UNAVAILABLE, message);
- break;
- case tensorflow::error::DATA_LOSS:
- TF_SetStatus(tf_status, TF_DATA_LOSS, message);
- break;
- default:
- assert(0);
- break;
- }
-
- errors::CopyPayloads(status, tf_status->status);
+ Set_TSL_Status_from_Status(tf_status, status);
}
Status StatusFromTF_Status(const TF_Status* tf_status) {
- return tf_status->status;
+ return StatusFromTSL_Status(tf_status);
}
} // namespace tsl
diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h
index 4c3c8af6864..df4600b85dc 100644
--- a/tensorflow/c/tf_status_helper.h
+++ b/tensorflow/c/tf_status_helper.h
@@ -21,10 +21,10 @@ limitations under the License.
namespace tsl {
// Set the attribute of "tf_status" from the attributes of "status".
-void Set_TF_Status_from_Status(TF_Status* tf_status, const tsl::Status& status);
+void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status);
// Returns a "status" from "tf_status".
-tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status);
+Status StatusFromTF_Status(const TF_Status* tf_status);
} // namespace tsl
namespace tensorflow {
diff --git a/tensorflow/c/tf_status_helper_test.cc b/tensorflow/c/tf_status_helper_test.cc
deleted file mode 100644
index 0bd9d1e4e3c..00000000000
--- a/tensorflow/c/tf_status_helper_test.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/c/tf_status_helper.h"
-
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace {
-
-TEST(StatusHelper, TestStatusHelper) {
- TF_Status* s = TF_NewStatus();
- Status cc_status(errors::InvalidArgument("some error"));
- cc_status.SetPayload("key1", "value1");
- cc_status.SetPayload("key2", "value2");
- Set_TF_Status_from_Status(s, cc_status);
- ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
- ASSERT_EQ(std::string("some error"), TF_Message(s));
-
- Status another_cc_status(StatusFromTF_Status(s));
- ASSERT_FALSE(another_cc_status.ok());
- ASSERT_EQ(std::string("some error"), another_cc_status.error_message());
- ASSERT_EQ(error::INVALID_ARGUMENT, another_cc_status.code());
- // Ensure the payloads are not lost during conversions
- ASSERT_EQ(cc_status.GetPayload("key1"), another_cc_status.GetPayload("key1"));
- ASSERT_EQ(cc_status.GetPayload("key2"), another_cc_status.GetPayload("key2"));
- TF_DeleteStatus(s);
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/c/tf_status_internal.h b/tensorflow/c/tf_status_internal.h
index 1e0f99819ff..7a40d6f518e 100644
--- a/tensorflow/c/tf_status_internal.h
+++ b/tensorflow/c/tf_status_internal.h
@@ -16,13 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#define TENSORFLOW_C_TF_STATUS_INTERNAL_H_
-#include "tensorflow/core/platform/status.h"
+#include "tensorflow/tsl/c/tsl_status_internal.h"
-// Internal structures used by the status C API. These are likely to change
-// and should not be depended on.
-
-struct TF_Status {
- tensorflow::Status status;
-};
+typedef struct TSL_Status TF_Status;
#endif // TENSORFLOW_C_TF_STATUS_INTERNAL_H_
diff --git a/tensorflow/c/tf_status_test.cc b/tensorflow/c/tf_status_test.cc
deleted file mode 100644
index 50f5dfb0f96..00000000000
--- a/tensorflow/c/tf_status_test.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/c/tf_status.h"
-
-#include
-
-#include "tensorflow/c/tf_status_internal.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace {
-
-TEST(TF_Status, PayloadsSet) {
- TF_Status* tf_status = TF_NewStatus();
- TF_SetStatus(tf_status, TF_CANCELLED, "Error Message");
- TF_SetPayload(tf_status, "a", "1");
- TF_SetPayload(tf_status, "b", "2");
- TF_SetPayload(tf_status, "c", "3");
-
- const std::unordered_map payloads =
- errors::GetPayloads(tf_status->status);
- EXPECT_EQ(payloads.size(), 3);
- EXPECT_EQ(payloads.at("a"), "1");
- EXPECT_EQ(payloads.at("b"), "2");
- EXPECT_EQ(payloads.at("c"), "3");
- TF_DeleteStatus(tf_status);
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc
index 7bf662d81e0..e007af200c4 100644
--- a/tensorflow/c/tf_tensor.cc
+++ b/tensorflow/c/tf_tensor.cc
@@ -247,7 +247,7 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
const int64_t* new_dims, int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
- s.AddDim(new_dims[i]);
+ TF_RETURN_IF_ERROR(s.AddDimWithStatus(new_dims[i]));
}
return tensor_.BitcastFrom(from.tensor_, type, s);
}
diff --git a/tensorflow/c/tf_tstring.h b/tensorflow/c/tf_tstring.h
index 5dc29f23d59..f9fb2fe083f 100644
--- a/tensorflow/c/tf_tstring.h
+++ b/tensorflow/c/tf_tstring.h
@@ -59,4 +59,4 @@ TF_CAPI_EXPORT extern void TF_StringDealloc(TF_TString *tstr);
} /* end extern "C" */
#endif
-#endif // THIRD_PARTY_TENSORFLOW_C_TF_TSTRING_H_
+#endif // TENSORFLOW_C_TF_TSTRING_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 04796a71711..4fc555871af 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -13,6 +13,7 @@ load(
load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_gen_op_wrappers_cc")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)
@@ -759,7 +760,7 @@ tf_gen_op_wrappers_cc(
"function_ops",
],
pkg = "//tensorflow/core",
- visibility = ["//tensorflow:internal"],
+ visibility = ["//visibility:public"],
)
tf_gen_op_wrappers_cc(
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
index 27ec4c0871d..3c5357f739e 100644
--- a/tensorflow/cc/client/client_session_test.cc
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/client/client_session.h"
+#include
#include
#include "absl/synchronization/barrier.h"
@@ -39,6 +40,14 @@ using ops::Mul;
using ops::Placeholder;
using ops::Sub;
+tensorflow::SessionOptions GetSessionOptions() {
+ tensorflow::SessionOptions options;
+ // Disable optimizations for static graph to allow calls to Session::Extend.
+ options.config.mutable_experimental()->set_disable_optimize_for_static_graph(
+ true);
+ return options;
+}
+
class CustomThreadPoolImpl : public thread::ThreadPoolInterface {
public:
explicit CustomThreadPoolImpl(int numThreads) {
@@ -100,7 +109,7 @@ TEST(ClientSessionTest, Extend) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32, Placeholder::Shape({2}));
auto c = Add(root, a, {2, 2});
- ClientSession session(root);
+ ClientSession session(root, GetSessionOptions());
std::vector outputs;
TF_EXPECT_OK(session.Run({{a, {1, 1}}}, {c}, &outputs));
@@ -116,7 +125,7 @@ TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) {
Scope root = Scope::NewRootScope();
auto a = Add(root, {1, 2}, {3, 4});
auto b = Mul(root, {1, 2}, {3, 4});
- ClientSession session(root);
+ ClientSession session(root, GetSessionOptions());
{
thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
thread_pool.Schedule([&session, a]() {
@@ -143,7 +152,7 @@ TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) {
int num_threads = 3;
auto a = Add(root, {1, 2}, {3, 4});
auto b = Mul(root, {1, 2}, {3, 4});
- ClientSession session(root);
+ ClientSession session(root, GetSessionOptions());
auto inter_op_threadpool =
absl::make_unique(num_threads);
diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD
index 5313b502bf5..7c1a040960f 100644
--- a/tensorflow/cc/experimental/base/public/BUILD
+++ b/tensorflow/cc/experimental/base/public/BUILD
@@ -11,6 +11,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
# This is intentionally public
default_visibility = [
"//visibility:public",
diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD
index 5f442faa77c..e749d2433bd 100644
--- a/tensorflow/cc/experimental/base/tests/BUILD
+++ b/tensorflow/cc/experimental/base/tests/BUILD
@@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
diff --git a/tensorflow/cc/experimental/libexport/BUILD b/tensorflow/cc/experimental/libexport/BUILD
index 5533cf76431..910ab930440 100644
--- a/tensorflow/cc/experimental/libexport/BUILD
+++ b/tensorflow/cc/experimental/libexport/BUILD
@@ -5,6 +5,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility =
[
"//tensorflow/cc/experimental/libtf:__subpackages__",
diff --git a/tensorflow/cc/experimental/libtf/BUILD b/tensorflow/cc/experimental/libtf/BUILD
index e9529725d94..e281672de9e 100644
--- a/tensorflow/cc/experimental/libtf/BUILD
+++ b/tensorflow/cc/experimental/libtf/BUILD
@@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup")
load("//tensorflow:strict.default.bzl", "py_strict_binary")
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
diff --git a/tensorflow/cc/experimental/libtf/impl/BUILD b/tensorflow/cc/experimental/libtf/impl/BUILD
index 8231a25102e..0eae5a1f05c 100644
--- a/tensorflow/cc/experimental/libtf/impl/BUILD
+++ b/tensorflow/cc/experimental/libtf/impl/BUILD
@@ -10,6 +10,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
diff --git a/tensorflow/cc/experimental/libtf/mlir/BUILD b/tensorflow/cc/experimental/libtf/mlir/BUILD
index 2d42d855dae..51336186510 100644
--- a/tensorflow/cc/experimental/libtf/mlir/BUILD
+++ b/tensorflow/cc/experimental/libtf/mlir/BUILD
@@ -6,6 +6,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
diff --git a/tensorflow/cc/experimental/libtf/object.h b/tensorflow/cc/experimental/libtf/object.h
index 8f510def431..72d05aaf430 100644
--- a/tensorflow/cc/experimental/libtf/object.h
+++ b/tensorflow/cc/experimental/libtf/object.h
@@ -166,7 +166,7 @@ class Object : public Handle {
if (class_dict_maybe.type() == TaggedValue::DICT) {
auto& dict = class_dict_maybe.dict();
auto it = dict.find(key.value_);
- if (it != value_.dict().end()) {
+ if (it != dict.end()) {
return Cast(Handle(it->second));
}
}
diff --git a/tensorflow/cc/experimental/libtf/runtime/BUILD b/tensorflow/cc/experimental/libtf/runtime/BUILD
index 75f81a5a8a2..b20c0e6e3f9 100644
--- a/tensorflow/cc/experimental/libtf/runtime/BUILD
+++ b/tensorflow/cc/experimental/libtf/runtime/BUILD
@@ -4,6 +4,7 @@ load(
)
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
diff --git a/tensorflow/cc/experimental/libtf/runtime/core/BUILD b/tensorflow/cc/experimental/libtf/runtime/core/BUILD
index cb750c4c7a4..83f61ee11ba 100644
--- a/tensorflow/cc/experimental/libtf/runtime/core/BUILD
+++ b/tensorflow/cc/experimental/libtf/runtime/core/BUILD
@@ -1,4 +1,5 @@
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
diff --git a/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD b/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD
index 6350e007875..586ef6b9523 100644
--- a/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD
+++ b/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD
@@ -1,4 +1,5 @@
package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc b/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc
index 0be93c31a28..59952002522 100644
--- a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc
+++ b/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc
@@ -21,7 +21,7 @@ namespace runtime {
INSTANTIATE_TEST_SUITE_P(TF2CAPI, RuntimeTest,
::testing::Values(core::Runtime));
-
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(RuntimeTest);
} // namespace runtime
} // namespace libtf
} // namespace tf
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index d0cd220f112..031451d3d2d 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/framework/cc_op_gen.h"
+#include
#include
#include
#include
diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h
index 9af3b9ce1e3..7b348365b33 100644
--- a/tensorflow/cc/framework/cc_op_gen.h
+++ b/tensorflow/cc/framework/cc_op_gen.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
+#include
+
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/cc/framework/cc_op_gen_util.h b/tensorflow/cc/framework/cc_op_gen_util.h
index d6c729f2dc9..8fb90356841 100644
--- a/tensorflow/cc/framework/cc_op_gen_util.h
+++ b/tensorflow/cc/framework/cc_op_gen_util.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_
#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_
+#include
#include
#include
#include
diff --git a/tensorflow/cc/framework/fuzzing/BUILD b/tensorflow/cc/framework/fuzzing/BUILD
index 4c6b0d80baf..c14b324fdf2 100644
--- a/tensorflow/cc/framework/fuzzing/BUILD
+++ b/tensorflow/cc/framework/fuzzing/BUILD
@@ -7,6 +7,8 @@ load(
)
load("//tensorflow:tensorflow.bzl", "tf_copts")
+# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+
cc_library(
name = "cc_op_fuzz_gen_main",
srcs = [
@@ -28,6 +30,7 @@ cc_library(
"//tensorflow/core/platform:hash",
"//tensorflow/tsl/platform:status",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
@@ -36,15 +39,8 @@ cc_library(
# tf_gen_op_wrappers_fuzz(
# name = "array_ops_fuzz",
# api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
-# extra_gen_deps = ["//tensorflow/c/kernels:bitcast_op_lib"],
-# op_lib_names = [
-# "array_ops",
-# ],
-# pkg = "//tensorflow/core",
-# deps = [
-# "//third_party/mediapipe/framework/port:parse_text_proto",
+# kernel_deps = [
# "//tensorflow/c/kernels:bitcast_op",
-# "//tensorflow/cc:cc_ops",
# "//tensorflow/core/kernels:array",
# "//tensorflow/core/kernels:check_numerics_op",
# "//tensorflow/core/kernels:fake_quant_ops",
@@ -57,6 +53,7 @@ cc_library(
# "//tensorflow/core/kernels/linalg:matrix_diag_op",
# "//tensorflow/core/kernels/linalg:matrix_set_diag_op",
# ],
+# op_def_src = "//tensorflow/core/ops:array_ops_op_lib",
# )
# copybara:uncomment_end
diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc
index 02af4b4aa86..416bb56e820 100644
--- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc
+++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc
@@ -202,11 +202,58 @@ string WriteFuzzTest(const OpInfo& op_info) {
}));
}
+string FuzzerFileStart() {
+ const string fuzz_namespace_begin = R"namespace(
+namespace tensorflow {
+namespace fuzzing {
+
+)namespace";
+
+ const string fuzz_header = strings::StrCat(
+ R"include(// This file is MACHINE GENERATED! Do not edit.
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/security/fuzzing/cc/fuzz_session.h"
+#include "third_party/mediapipe/framework/port/parse_text_proto.h"
+)include",
+ fuzz_namespace_begin);
+
+ return fuzz_header;
+}
+
+string FuzzerFileEnd() {
+ const string fuzz_footer = R"footer(
+} // namespace fuzzing
+} // namespace tensorflow
+)footer";
+
+ return fuzz_footer;
+}
+
+} // namespace
+
bool OpFuzzingIsOk(const OpInfo& op_info) {
+ // Skip deprecated ops.
+ if (op_info.graph_op_def.has_deprecation() &&
+ op_info.graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
+ std::cout << "NOT fuzzing: " << op_info.graph_op_def.name()
+ << " is deprecated.\n";
+ return false;
+ }
+
// TODO(unda, b/249347507): should we hide fuzzers for hidden ops?
- if (op_info.api_def.visibility() == ApiDef::HIDDEN) return false;
+ if (op_info.api_def.visibility() == ApiDef::HIDDEN) {
+ std::cout << "NOT fuzzing: " << op_info.graph_op_def.name()
+ << " is hidden.\n";
+ return false;
+ }
- if (op_info.api_def.visibility() == ApiDef::SKIP) return false;
+ if (op_info.api_def.visibility() == ApiDef::SKIP) {
+ std::cout << "NOT fuzzing: " << op_info.graph_op_def.name()
+ << " is skipped.\n";
+ return false;
+ }
// TODO(unda) : zero input ops
std::set<string> zero_input_ops = {"Placeholder", "ImmutableConst"};
@@ -272,56 +319,10 @@ bool OpFuzzingIsOk(const OpInfo& op_info) {
return true;
}
-string FuzzerFileStart() {
- const string fuzz_namespace_begin = R"namespace(
-namespace tensorflow {
-namespace fuzzing {
-
-)namespace";
-
- const string fuzz_header = strings::StrCat(
- R"include(// This file is MACHINE GENERATED! Do not edit.
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/security/fuzzing/cc/fuzz_session.h"
-#include "third_party/mediapipe/framework/port/parse_text_proto.h"
-)include",
- fuzz_namespace_begin);
-
- return fuzz_header;
-}
-
-string FuzzerFileEnd() {
- const string fuzz_footer = R"footer(
-} // namespace fuzzing
-} // namespace tensorflow
-)footer";
-
- return fuzz_footer;
-}
-
-} // namespace
-
-string WriteFuzzers(const OpList& ops, const ApiDefMap& api_def_map) {
+string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable) {
return absl::StrCat(
- FuzzerFileStart(),
- absl::StrJoin(
- ops.op(), "",
- [&api_def_map](string* out, const OpDef& op_def) {
- // Skip deprecated ops.
- bool skip = op_def.has_deprecation() &&
- op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION;
- const auto* api_Def = api_def_map.GetApiDef(op_def.name());
-          OpInfo op_info(op_def, *api_Def, std::vector<string>());
- skip |= !OpFuzzingIsOk(op_info);
- if (!skip) {
- out->append(WriteClassFuzzDef(op_info));
- out->append(WriteFuzzTest(op_info));
- out->append("\n");
- }
- }),
- FuzzerFileEnd());
+ FuzzerFileStart(), is_fuzzable ? WriteClassFuzzDef(op_info) : "",
+ is_fuzzable ? WriteFuzzTest(op_info) : "", FuzzerFileEnd());
}
} // namespace cc_op
diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h
index 6770430ad69..c11c9635d6d 100644
--- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h
+++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_
#define TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_
+#include "tensorflow/cc/framework/cc_op_gen_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
@@ -23,8 +24,11 @@ limitations under the License.
namespace tensorflow {
namespace cc_op {
-/// String with fuzzer file contents.
-string WriteFuzzers(const OpList& ops, const ApiDefMap& api_def_map);
+// String with single fuzzer file content.
+string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable);
+
+// Do we have all we need to create a fuzzer
+bool OpFuzzingIsOk(const OpInfo& op_info);
} // namespace cc_op
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
index 0a1de103d37..99388eb8847 100644
--- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
+++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc
@@ -14,10 +14,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <algorithm>
#include
#include
#include
+#include "absl/status/status.h"
#include "tensorflow/cc/framework/cc_op_gen_util.h"
#include "tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h"
#include "tensorflow/core/framework/op_def.pb.h"
@@ -33,20 +35,33 @@ namespace tensorflow {
namespace cc_op {
namespace {
-void WriteAllFuzzers(const std::string& file_name, bool include_internal,
-                     const std::vector<string>& api_def_dirs) {
+void WriteAllFuzzers(string root_location, std::vector<string> api_def_dirs,
+                     std::vector<string> op_names) {
OpList ops;
-  StatusOr api_def_map =
-      LoadOpsAndApiDefs(ops, include_internal, api_def_dirs);
+  StatusOr<ApiDefMap> api_def_map = LoadOpsAndApiDefs(ops, false, api_def_dirs);
TF_CHECK_OK(api_def_map.status());
- WriteFuzzers(ops, api_def_map.value());
Env* env = Env::Default();
+ tsl::Status status;
std::unique_ptr<WritableFile> fuzz_file = nullptr;
- auto status = env->NewWritableFile(file_name, &fuzz_file);
- status.Update(fuzz_file->Append(WriteFuzzers(ops, api_def_map.value())));
- status.Update(fuzz_file->Close());
+ for (const OpDef& op_def : ops.op()) {
+ if (std::find(op_names.begin(), op_names.end(), op_def.name()) ==
+ op_names.end())
+ continue;
+
+ const ApiDef* api_def = api_def_map->GetApiDef(op_def.name());
+ if (api_def == nullptr) {
+ continue;
+ }
+
+  OpInfo op_info(op_def, *api_def, std::vector<string>());
+ status.Update(env->NewWritableFile(
+ root_location + "/" + op_def.name() + "_fuzz.cc", &fuzz_file));
+ status.Update(
+ fuzz_file->Append(WriteSingleFuzzer(op_info, OpFuzzingIsOk(op_info))));
+ status.Update(fuzz_file->Close());
+ }
TF_CHECK_OK(status);
}
@@ -60,17 +75,17 @@ int main(int argc, char* argv[]) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
- fprintf(stderr,
- "Usage: %s out include_internal "
- "api_def_dirs1,api_def_dir2 ...\n"
- " include_internal: 1 means include internal ops\n",
+ fprintf(stderr, "Usage: %s location api_def1,api_def2 op1,op2,op3\n",
argv[0]);
exit(1);
}
-
- bool include_internal = tensorflow::StringPiece("1") == argv[2];
-  std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
+ for (int i = 1; i < argc; ++i) {
+ fprintf(stdout, "Arg %d = %s\n", i, argv[i]);
+ }
+  std::vector<tensorflow::string> api_def_srcs = tensorflow::str_util::Split(
+      argv[2], ",", tensorflow::str_util::SkipEmpty());
+  std::vector<tensorflow::string> op_names = tensorflow::str_util::Split(
argv[3], ",", tensorflow::str_util::SkipEmpty());
- tensorflow::cc_op::WriteAllFuzzers(argv[1], include_internal, api_def_dirs);
+ tensorflow::cc_op::WriteAllFuzzers(argv[1], api_def_srcs, op_names);
return 0;
}
diff --git a/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl b/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl
index aac616f8928..2dfb4d08589 100644
--- a/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl
+++ b/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl
@@ -12,108 +12,160 @@ load(
"cc_test",
)
-def tf_gen_op_wrapper_fuzz(
+def tf_gen_op_wrappers_fuzz(
name,
- out_ops_file,
- pkg = "",
- deps = None,
- include_internal_ops = 0,
- api_def_srcs = []):
+ op_def_src,
+ api_def_srcs = [],
+ kernel_deps = []):
"""
- Generates a file with fuzzers for a subset of ops.
+ Generates fuzzers for several groups of ops.
+
+ For each one we need the corresponding OpDef, ApiDef and KernelDef,
+ since they all can contain constraints for the inputs.
Args:
- name: name of the op class
- out_ops_file: prefix for file generation
- pkg: where to find op registrations
- deps: depedencies
- include_internal_ops: true if we should generate internal ops
- api_def_srcs: which op definitions to use
+ name: the name of the fuzz artifact
+ op_def_src: op definitions
+ api_def_srcs: api definitions
+ kernel_deps: op kernel dependencies
"""
- tool = out_ops_file + "_gen_fuzz"
- if deps == None:
- deps = [pkg + ":" + name + "_op_lib"]
+ # Create tool to generate .cc fuzzer files.
tf_cc_binary(
- name = tool,
+ name = "op_fuzz_gen_tool",
copts = tf_copts(),
linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + lrt_if_needed(),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = [
"//tensorflow/cc/framework/fuzzing:cc_op_fuzz_gen_main",
- ] + deps,
+ op_def_src,
+ ] + kernel_deps,
)
- srcs = api_def_srcs[:]
+ # Add relevant locations to look for api_defs.
+ api_def_src_locations = ",".join(["$$(dirname $$(echo $(locations " + api_def_src + ") | cut -d\" \" -f1))" for api_def_src in api_def_srcs])
- if not api_def_srcs:
- api_def_args_str = ","
- else:
- api_def_args = []
- for api_def_src in api_def_srcs:
- # Add directory of the first ApiDef source to args.
- # We are assuming all ApiDefs in a single api_def_src are in the
- # same directory.
- api_def_args.append(
- " $$(dirname $$(echo $(locations " + api_def_src +
- ") | cut -d\" \" -f1))",
- )
- api_def_args_str = ",".join(api_def_args)
-
- out_fuzz_file = out_ops_file + "_fuzz.cc"
+ out_fuzz_files = [op_name + "_fuzz.cc" for op_name in op_names]
native.genrule(
name = name + "_genrule",
- outs = [
- out_fuzz_file,
- ],
- srcs = srcs,
- tools = [":" + tool], # + tf_binary_additional_srcs(),
- cmd = ("$(location :" + tool + ") $(location :" + out_fuzz_file + ") " +
- str(include_internal_ops) + " " + api_def_args_str),
+ outs = out_fuzz_files,
+ srcs = api_def_srcs,
+ tools = [":op_fuzz_gen_tool"],
+ cmd = ("$(location :op_fuzz_gen_tool) " +
+ " $$(dirname $(location " + out_fuzz_files[0] + "))" +
+ " " + api_def_src_locations + " " + (",".join(op_names))),
)
-def tf_gen_op_wrappers_fuzz(
- name,
- op_lib_names = [],
- pkg = "",
- deps = [
- "//tensorflow/cc:ops",
- "//tensorflow/cc:scope",
- "//tensorflow/cc:const_op",
- ],
- include_internal_ops = 0,
- api_def_srcs = [],
- extra_gen_deps = []):
- """
- Generates fuzzers for several groups of ops.
-
- Args:
- name: the name of the fuzz artifact
- op_lib_names: which op libraries to fuzz
- pkg: where to find op registrations
- deps: dependencies
- include_internal_ops: true if we should generate internal ops
- api_def_srcs: where to find the op definitions
- extra_gen_deps: extra dependencies for generation
- """
- fuzzsrcs = []
- for n in op_lib_names:
- tf_gen_op_wrapper_fuzz(
- n,
- "fuzzers/" + n,
- api_def_srcs = api_def_srcs,
- include_internal_ops = include_internal_ops,
- pkg = pkg,
- deps = [pkg + ":" + n + "_op_lib"] + extra_gen_deps,
+ for op_name in op_names:
+ cc_test(
+ name = op_name.lower() + "_fuzz",
+ srcs = [op_name + "_fuzz.cc"],
+ deps = kernel_deps +
+ [
+ "//tensorflow/security/fuzzing/cc:fuzz_session",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_fuzztest//fuzztest",
+ "//tensorflow/cc:cc_ops",
+ "//third_party/mediapipe/framework/port:parse_text_proto",
+ ],
)
- fuzzsrcs.append("fuzzers/" + n + "_fuzz.cc")
- cc_test(
- name = name,
- srcs = fuzzsrcs,
- deps = deps +
- [
- "//tensorflow/security/fuzzing/cc:fuzz_session",
- "@com_google_googletest//:gtest_main",
- "@com_google_fuzztest//fuzztest",
- ],
- )
+
+op_names = [
+ "BatchMatrixBandPart",
+ "BatchMatrixDiag",
+ "BatchMatrixDiagPart",
+ "BatchMatrixSetDiag",
+ "BatchToSpace",
+ "BatchToSpaceND",
+ "Bitcast",
+ "BroadcastArgs",
+ "BroadcastTo",
+ "CheckNumerics",
+ "ConcatV2",
+ "ConjugateTranspose",
+ "DebugGradientIdentity",
+ "DeepCopy",
+ "DepthToSpace",
+ "Dequantize",
+ "EditDistance",
+ "Empty",
+ "EnsureShape",
+ "ExpandDims",
+ "ExtractImagePatches",
+ "ExtractVolumePatches",
+ "FakeQuantWithMinMaxArgs",
+ "FakeQuantWithMinMaxArgsGradient",
+ "FakeQuantWithMinMaxVars",
+ "FakeQuantWithMinMaxVarsGradient",
+ "FakeQuantWithMinMaxVarsPerChannel",
+ "FakeQuantWithMinMaxVarsPerChannelGradient",
+ "Fill",
+ "Fingerprint",
+ "Gather",
+ "GuaranteeConst",
+ "Identity",
+ "IdentityN",
+ "InplaceAdd",
+ "InplaceSub",
+ "InplaceUpdate",
+ "InvertPermutation",
+ "ListDiff",
+ "MatrixBandPart",
+ "MatrixDiag",
+ "MatrixDiagPart",
+ "MatrixDiagPartV2",
+ "MatrixDiagPartV3",
+ "MatrixDiagV2",
+ "MatrixDiagV3",
+ "MatrixSetDiag",
+ "MatrixSetDiagV2",
+ "MatrixSetDiagV3",
+ "MirrorPad",
+ "OneHot",
+ "OnesLike",
+ "Pack",
+ "Pad",
+ "PadV2",
+ "ParallelConcat",
+ "PlaceholderV2",
+ "PlaceholderWithDefault",
+ "PreventGradient",
+ "QuantizeAndDequantize",
+ "QuantizeV2",
+ "Rank",
+ "Reshape",
+ "ResourceStridedSliceAssign",
+ "ReverseSequence",
+ "ReverseV2",
+ "ScatterNdNonAliasingAdd",
+ "Shape",
+ "ShapeN",
+ "Size",
+ "Slice",
+ "Snapshot",
+ "SpaceToBatch",
+ "SpaceToBatchND",
+ "SpaceToDepth",
+ "Split",
+ "SplitV",
+ "Squeeze",
+ "StopGradient",
+ "StridedSlice",
+ "StridedSliceGrad",
+ "TensorScatterAdd",
+ "TensorScatterMax",
+ "TensorScatterMin",
+ "TensorScatterSub",
+ "TensorStridedSliceUpdate",
+ "Tile",
+ "TileGrad",
+ "Transpose",
+ "Unique",
+ "UniqueV2",
+ "UniqueWithCounts",
+ "UniqueWithCountsV2",
+ "Unpack",
+ "UnravelIndex",
+ "Where",
+ "ZerosLike",
+]
diff --git a/tensorflow/cc/framework/grad_op_registry.h b/tensorflow/cc/framework/grad_op_registry.h
index 0fc5abb20c8..951144cf8ce 100644
--- a/tensorflow/cc/framework/grad_op_registry.h
+++ b/tensorflow/cc/framework/grad_op_registry.h
@@ -16,7 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
+#include
#include
+#include
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index 0013ea732df..0c026cf9a0c 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
+#include
+#include
+
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h
index 1aa215a9088..b8db767f77c 100644
--- a/tensorflow/cc/framework/gradient_checker.h
+++ b/tensorflow/cc/framework/gradient_checker.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
#define TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
+#include <vector>
+
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index cdb3d0c7d68..3dd2ab3ab82 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -16,6 +16,11 @@ limitations under the License.
#include "tensorflow/cc/framework/gradients.h"
#include
+#include