diff --git a/.asf.yaml b/.asf.yaml index 50886f2cea5a..a6449ffb8b5f 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -49,6 +49,7 @@ github: protected_branches: master: {} + release-2.62.0: {} release-2.61.0: {} release-2.60.0: {} release-2.59.0: {} diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json index 3f63c0c9975f..bbdc3a3910ef 100644 --- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json +++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json index 920c8d132e4a..8784d0786c02 100644 --- a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json +++ b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Java_IO_Performance_Tests.json b/.github/trigger_files/beam_PostCommit_Java_IO_Performance_Tests.json new file mode 100644 index 000000000000..b26833333238 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_IO_Performance_Tests.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index f1ba03a243ee..455144f02a35 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to 
run", - "modification": 5 + "modification": 6 } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json index 9b023f630c36..dd2bf3aeb361 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json @@ -2,5 +2,7 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", - "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test" + "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test", + "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json index 9b023f630c36..74f4220571e5 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json @@ -2,5 +2,6 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", - "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test" + "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33267": "noting that PR 
#33267 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json index 9b023f630c36..dd2bf3aeb361 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json @@ -2,5 +2,7 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", - "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test" + "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test", + "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 00bd9e035648..dd3d3e011a0c 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 6 + "modification": 8 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json index b26833333238..e3d6056a5de9 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 1 } diff --git 
a/.github/trigger_files/beam_PreCommit_Java_HBase_IO_Direct.json b/.github/trigger_files/beam_PreCommit_Java_HBase_IO_Direct.json new file mode 100644 index 000000000000..0967ef424bce --- /dev/null +++ b/.github/trigger_files/beam_PreCommit_Java_HBase_IO_Direct.json @@ -0,0 +1 @@ +{} diff --git a/.github/workflows/IO_Iceberg_Integration_Tests.yml b/.github/workflows/IO_Iceberg_Integration_Tests.yml index 68a72790006f..5a9e04968c8a 100644 --- a/.github/workflows/IO_Iceberg_Integration_Tests.yml +++ b/.github/workflows/IO_Iceberg_Integration_Tests.yml @@ -75,4 +75,4 @@ jobs: - name: Run IcebergIO Integration Test uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :sdks:java:io:iceberg:catalogTests --info \ No newline at end of file + gradle-command: :sdks:java:io:iceberg:integrationTest --info \ No newline at end of file diff --git a/.github/workflows/IO_Iceberg_Unit_Tests.yml b/.github/workflows/IO_Iceberg_Unit_Tests.yml index 0d72b0da8597..d063f6ac71db 100644 --- a/.github/workflows/IO_Iceberg_Unit_Tests.yml +++ b/.github/workflows/IO_Iceberg_Unit_Tests.yml @@ -111,6 +111,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 206364f416f7..b9069c530e53 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -15,6 +15,27 @@ under the License. --> +# How to fix Workflows for Committers + +The following is guidance on how to practically make changes that fix workflows. + +1) Create a branch in https://github.com/apache/beam not your fork. + + The reason to perform changes to a branch of the main repo instead of your fork is due to the challenge in replicating the environment within which Beam GitHub workflows execute. 
GitHub workflows allow you to execute against a branch of a repo. + +2) Make changes in this branch you anticipate will fix the failing workflow. + +3) Run the workflow designating your branch. + + In the GitHub workflow interface, you can designate any branch of the repository to run the workflow against. Selecting your branch allows you to test the changes you made. The following screenshot shows an example of this feature. + ![image](https://github.com/user-attachments/assets/33ca43fb-b0f8-42c8-80e2-ac84a49e2490) + +4) Create a PR, pasting the link to your successful workflow run in the branch. + + When doing a PR, the checks will not run against your branch. Your reviewer may not know this so you'll want to mention this in your PR description, pasting the link to your successful run. + +5) After PR merges, execute the workflow manually to validate your merged changes. + # Running Workflows Manually Most workflows will get kicked off automatically when you open a PR, push code, or on a schedule. 
@@ -207,7 +228,6 @@ PreCommit Jobs run in a schedule and also get triggered in a PR if relevant sour | [ PreCommit Go ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Go.yml) | N/A |`Run Go PreCommit`| [![.github/workflows/beam_PreCommit_Go.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Go.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Go.yml?query=event%3Aschedule) | | [ PreCommit GoPortable ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_GoPortable.yml) | N/A |`Run GoPortable PreCommit`| [![.github/workflows/beam_PreCommit_GoPortable.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_GoPortable.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_GoPortable.yml?query=event%3Aschedule) | | [ PreCommit Java ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java.yml) | N/A |`Run Java PreCommit`| [![.github/workflows/beam_PreCommit_Java.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java.yml?query=event%3Aschedule) | -| [ PreCommit Java Amazon Web Services IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml) | N/A |`Run Java_Amazon-Web-Services_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Amazon Web Services2 IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml) | N/A |`Run Java_Amazon-Web-Services2_IO_Direct PreCommit`| 
[![.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Amqp IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amqp_IO_Direct.yml) | N/A |`Run Java_Amqp_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Amqp_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amqp_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Amqp_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Azure IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml) | N/A |`Run Java_Azure_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml?query=event%3Aschedule) | @@ -231,7 +251,6 @@ PreCommit Jobs run in a schedule and also get triggered in a PR if relevant sour | [ PreCommit Java IOs Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_IOs_Direct.yml) | N/A |`Run Java_IOs_Direct PreCommit`| N/A | | [ PreCommit Java JDBC IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml) | N/A |`Run Java_JDBC_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit 
Java Jms IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml) | N/A |`Run Java_Jms_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml?query=event%3Aschedule) | -| [ PreCommit Java Kinesis IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml) | N/A |`Run Java_Kinesis_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Kudu IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml) | N/A |`Run Java_Kudu_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java MongoDb IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml) | N/A |`Run Java_MongoDb_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Mqtt IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml) | N/A |`Run 
Java_Mqtt_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml?query=event%3Aschedule) | diff --git a/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml b/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml index 2b631d2f7664..659c85b002df 100644 --- a/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml +++ b/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml @@ -124,4 +124,5 @@ jobs: uses: EnricoMi/publish-unit-test-result-action@v2 if: always() with: - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml index af0569f4784a..74932079fe4c 100644 --- a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml +++ b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml index 9e3962e2576e..05e5369a6384 100644 --- a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml +++ b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || 
env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml b/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml index 7514bd5cacb3..32db2cff6cbc 100644 --- a/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml +++ b/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml b/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml index 6ac07a1bd76c..d04a6e63c800 100644 --- a/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml +++ b/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml b/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml index e9ef9cd1716a..756ecb5a58c2 100644 --- a/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml +++ b/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml @@ -115,4 +115,5 @@ jobs: with: commit: '${{ env.prsha || 
env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java.yml b/.github/workflows/beam_PostCommit_Java.yml index 3428551cb8f9..4fafa3b2a993 100644 --- a/.github/workflows/beam_PostCommit_Java.yml +++ b/.github/workflows/beam_PostCommit_Java.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml b/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml index e3a9db23ed67..8ffcc4a28a71 100644 --- a/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml +++ b/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml b/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml index 1a6f7c14db50..8707b515e10b 100644 --- a/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml +++ b/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml @@ -110,3 +110,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_DataflowV1.yml 
b/.github/workflows/beam_PostCommit_Java_DataflowV1.yml index e7c2aa6fe7e2..752b15936b5f 100644 --- a/.github/workflows/beam_PostCommit_Java_DataflowV1.yml +++ b/.github/workflows/beam_PostCommit_Java_DataflowV1.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_DataflowV2.yml b/.github/workflows/beam_PostCommit_Java_DataflowV2.yml index 3c0a46d6bb40..cb107572b621 100644 --- a/.github/workflows/beam_PostCommit_Java_DataflowV2.yml +++ b/.github/workflows/beam_PostCommit_Java_DataflowV2.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml index 469d7e31f173..81725c4005af 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml @@ -89,4 +89,5 @@ jobs: uses: EnricoMi/publish-unit-test-result-action@v2 if: always() with: - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml index 9fd84daef63b..eacdfe5a5c23 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml +++ 
b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml @@ -119,3 +119,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml index 13ab05f8f173..efb926681cbf 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml @@ -97,4 +97,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml index 9be8a34f3732..1882cdf1d76b 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml index cd2486ae8e10..05b28ac93658 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml @@ -104,4 +104,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ 
github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml b/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml index ca06e72877c7..a746acb4333f 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml b/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml index e42d6a88b8df..f72910bd15bc 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml @@ -94,3 +94,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml b/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml index 8008daf4584f..c3620e46fac9 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git 
a/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml b/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml index 67a48b105955..1202ecc0e27f 100644 --- a/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml +++ b/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml @@ -100,4 +100,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml b/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml index a6a2749c8d82..6023a895a458 100644 --- a/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml +++ b/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml @@ -88,11 +88,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: java-version: default - - name: Authenticate on GCP - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: run scheduled javaPostcommitIOPerformanceTests script if: github.event_name == 'schedule' #This ensures only scheduled runs publish metrics publicly by changing which exportTable is configured uses: ./.github/actions/gradle-command-self-hosted-action @@ -122,3 +117,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml index 37f784770477..323f85b9851a 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml @@ -91,4 
+91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml index 377602ad08dd..1ccb26f5aa1f 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml @@ -96,4 +96,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml index 80406cf4eb0c..02ac93135957 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml index 3cbc317317c2..2cbf60a48d2e 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml @@ -96,4 +96,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 
'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml index 97fd1fb4913e..6a7058ef566d 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml @@ -97,4 +97,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml index 1a7405836f69..1559061634d3 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml index eec4867a997b..1b4f8c5bcce5 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: 
'**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml b/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml index 987be7789b29..8c5fcb1acff4 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml @@ -91,3 +91,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml b/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml index 7cc48ebd4b0e..c1a22b9c871d 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml @@ -100,4 +100,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml b/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml index ad10bfc684d8..76ab560f15ec 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml 
b/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml index 4fb236c7c991..73fd6f0b78fa 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml @@ -113,3 +113,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml index d66381393725..c85c0b8468dc 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml index da2ba2f88465..5963a33007e0 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml @@ -111,4 +111,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml 
b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml index edb055321c87..2e8227fb84a6 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml index 8957ce7de053..2abc081e6ae5 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml index 2a98746a0b84..fde10e0898e9 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git 
a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml index 3f48bb921805..f439be9ec58e 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml index 75ebbda93f80..eb70a654c93d 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml @@ -106,4 +106,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml index 794308d3a85e..edcb45303fd4 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml @@ -96,4 +96,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git 
a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml index d1f264aaac01..d05963263931 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml index 15863d4c8c9b..da04582a7caa 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml index c05284186617..8d531c120dd6 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml @@ -108,4 +108,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline 
at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml index 522cb300c687..8310e5ed8bb2 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml index 36fc06aea421..3b130b6d290f 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml @@ -89,4 +89,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_PortableJar_Flink.yml b/.github/workflows/beam_PostCommit_PortableJar_Flink.yml index 37bfe68d9b20..318b5104c39c 100644 --- a/.github/workflows/beam_PostCommit_PortableJar_Flink.yml +++ b/.github/workflows/beam_PostCommit_PortableJar_Flink.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_PortableJar_Spark.yml 
b/.github/workflows/beam_PostCommit_PortableJar_Spark.yml index ce7be60133d7..0712dfb255b7 100644 --- a/.github/workflows/beam_PostCommit_PortableJar_Spark.yml +++ b/.github/workflows/beam_PostCommit_PortableJar_Spark.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python.yml b/.github/workflows/beam_PostCommit_Python.yml index 4770515c75fb..93b85a318487 100644 --- a/.github/workflows/beam_PostCommit_Python.yml +++ b/.github/workflows/beam_PostCommit_Python.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Arm.yml b/.github/workflows/beam_PostCommit_Python_Arm.yml index 48fb00b1bb9d..352b95e6747a 100644 --- a/.github/workflows/beam_PostCommit_Python_Arm.yml +++ b/.github/workflows/beam_PostCommit_Python_Arm.yml @@ -124,4 +124,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Dependency.yml b/.github/workflows/beam_PostCommit_Python_Dependency.yml index 6e7c4ddbd3eb..80e1bbc290c9 100644 --- a/.github/workflows/beam_PostCommit_Python_Dependency.yml +++ b/.github/workflows/beam_PostCommit_Python_Dependency.yml @@ -96,3 +96,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 
'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml index 4ce3b1893215..bf8330a2ae58 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml b/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml index a6bb49f4e444..e271b7da9a7b 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml b/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml index f23674a2c70a..28fd13c181b3 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml @@ -102,3 +102,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml b/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml index d866d412507b..5df6bcf8c01c 100644 --- 
a/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml b/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml index 578775a9d3ed..0d334b679dc5 100644 --- a/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml +++ b/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml index bcd936324124..6e16f43476b2 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml @@ -108,3 +108,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml index f2eba045722c..3ab7257f8a9d 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml @@ -106,4 +106,5 @@ jobs: with: commit: '${{ env.prsha || 
env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml index 6f8a7bdd0631..c294dd3c9068 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml @@ -118,3 +118,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml index 1876950c7a93..f8daa1a96634 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml index f837c7476e12..9277bd68fc01 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml @@ -103,4 +103,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + 
files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml index 91c249adf338..e058724cd2ac 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml index 7e87aaff22cc..a47f758ed410 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml index b3f37c6b39f0..bd266cf6fdab 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml 
b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml index 137d7bc13d2f..6d26d1c46012 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml index 8fc0db189078..08e99fa0fe0f 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml @@ -95,4 +95,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml index 5092a1981154..a7643c795af4 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_SQL.yml b/.github/workflows/beam_PostCommit_SQL.yml index c7d0b6dc98b9..aebea2b0564b 100644 --- a/.github/workflows/beam_PostCommit_SQL.yml +++ b/.github/workflows/beam_PostCommit_SQL.yml @@ -91,3 +91,4 @@ jobs: commit: '${{ env.prsha 
|| env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_TransformService_Direct.yml b/.github/workflows/beam_PostCommit_TransformService_Direct.yml index cb339eb9fb40..d0d72f3df13c 100644 --- a/.github/workflows/beam_PostCommit_TransformService_Direct.yml +++ b/.github/workflows/beam_PostCommit_TransformService_Direct.yml @@ -98,4 +98,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Direct.yml b/.github/workflows/beam_PostCommit_XVR_Direct.yml index 023ae4f8cd31..af8b7fb1bf54 100644 --- a/.github/workflows/beam_PostCommit_XVR_Direct.yml +++ b/.github/workflows/beam_PostCommit_XVR_Direct.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Flink.yml b/.github/workflows/beam_PostCommit_XVR_Flink.yml index 1f1d7d863b7e..fe4404247448 100644 --- a/.github/workflows/beam_PostCommit_XVR_Flink.yml +++ b/.github/workflows/beam_PostCommit_XVR_Flink.yml @@ -111,3 +111,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml index 228f10b90cd0..0620023ce7d2 100644 --- 
a/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml @@ -102,3 +102,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml index 66770c9a1683..11a8a5c5f4f7 100644 --- a/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml @@ -95,4 +95,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml index bfb602f89daf..c393a4113589 100644 --- a/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml index f1269a0ddd09..082aeb3f2ab2 100644 --- a/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml @@ -95,4 +95,5 @@ jobs: with: commit: '${{ 
env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Samza.yml b/.github/workflows/beam_PostCommit_XVR_Samza.yml index 2d26c9131839..7e2dca61d41d 100644 --- a/.github/workflows/beam_PostCommit_XVR_Samza.yml +++ b/.github/workflows/beam_PostCommit_XVR_Samza.yml @@ -111,4 +111,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Spark3.yml b/.github/workflows/beam_PostCommit_XVR_Spark3.yml index c1880e01292b..17fb58d9dd73 100644 --- a/.github/workflows/beam_PostCommit_XVR_Spark3.yml +++ b/.github/workflows/beam_PostCommit_XVR_Spark3.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_ItFramework.yml b/.github/workflows/beam_PreCommit_ItFramework.yml index e078d4645757..e803fc023c67 100644 --- a/.github/workflows/beam_PreCommit_ItFramework.yml +++ b/.github/workflows/beam_PreCommit_ItFramework.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git 
a/.github/workflows/beam_PreCommit_Java.yml b/.github/workflows/beam_PreCommit_Java.yml index 20dafca72a57..bc25fb94f8f0 100644 --- a/.github/workflows/beam_PreCommit_Java.yml +++ b/.github/workflows/beam_PreCommit_Java.yml @@ -198,6 +198,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml index ecbc85ca1b1d..cf0d0b660782 100644 --- a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml @@ -130,6 +130,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml deleted file mode 100644 index 55935251e6d9..000000000000 --- a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml +++ /dev/null @@ -1,144 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: PreCommit Java Amazon-Web-Services IO Direct - -on: - push: - tags: ['v*'] - branches: ['master', 'release-*'] - paths: - - "sdks/java/io/amazon-web-services/**" - - "sdks/java/io/common/**" - - "sdks/java/core/src/main/**" - - "build.gradle" - - "buildSrc/**" - - "gradle/**" - - "gradle.properties" - - "gradlew" - - "gradle.bat" - - "settings.gradle.kts" - - ".github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml" - pull_request_target: - branches: ['master', 'release-*'] - paths: - - "sdks/java/io/amazon-web-services/**" - - "sdks/java/io/common/**" - - "sdks/java/core/src/main/**" - - 'release/trigger_all_tests.json' - - '.github/trigger_files/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.json' - - "build.gradle" - - "buildSrc/**" - - "gradle/**" - - "gradle.properties" - - "gradlew" - - "gradle.bat" - - "settings.gradle.kts" - issue_comment: - types: [created] - schedule: - - cron: '0 1/6 * * *' - workflow_dispatch: - -#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event -permissions: - actions: write - pull-requests: write - checks: write - contents: read - deployments: read - id-token: none - issues: write - discussions: read - packages: read - pages: read - repository-projects: read - security-events: read - statuses: read - -# This allows a subsequently queued workflow run to interrupt previous runs -concurrency: - group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || 
github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' - cancel-in-progress: true - -env: - DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} - GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} - GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} - -jobs: - beam_PreCommit_Java_Amazon-Web-Services_IO_Direct: - name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - strategy: - matrix: - job_name: ["beam_PreCommit_Java_Amazon-Web-Services_IO_Direct"] - job_phrase: ["Run Java_Amazon-Web-Services_IO_Direct PreCommit"] - timeout-minutes: 60 - if: | - github.event_name == 'push' || - github.event_name == 'pull_request_target' || - (github.event_name == 'schedule' && github.repository == 'apache/beam') || - github.event_name == 'workflow_dispatch' || - github.event.comment.body == 'Run Java_Amazon-Web-Services_IO_Direct PreCommit' - runs-on: [self-hosted, ubuntu-20.04, main] - steps: - - uses: actions/checkout@v4 - - name: Setup repository - uses: ./.github/actions/setup-action - with: - comment_phrase: ${{ matrix.job_phrase }} - github_token: ${{ secrets.GITHUB_TOKEN }} - github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - - name: Setup environment - uses: ./.github/actions/setup-environment-action - - name: run Amazon-Web-Services IO build script - uses: ./.github/actions/gradle-command-self-hosted-action - with: - gradle-command: :sdks:java:io:amazon-web-services:build - arguments: | - -PdisableSpotlessCheck=true \ - -PdisableCheckStyle=true \ - - name: run Amazon-Web-Services IO IT script - uses: ./.github/actions/gradle-command-self-hosted-action - with: - gradle-command: :sdks:java:io:amazon-web-services:integrationTest - arguments: | - -PdisableSpotlessCheck=true \ - -PdisableCheckStyle=true \ - - name: Archive JUnit Test Results - uses: actions/upload-artifact@v4 - if: ${{ !success() }} - with: - name: JUnit Test Results - path: 
"**/build/reports/tests/" - - name: Publish JUnit Test Results - uses: EnricoMi/publish-unit-test-result-action@v2 - if: always() - with: - commit: '${{ env.prsha || env.GITHUB_SHA }}' - comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' - - name: Archive SpotBugs Results - uses: actions/upload-artifact@v4 - if: always() - with: - name: SpotBugs Results - path: '**/build/reports/spotbugs/*.html' - - name: Publish SpotBugs Results - uses: jwgmeligmeyling/spotbugs-github-action@v1.2 - if: always() - with: - name: Publish SpotBugs - path: '**/build/reports/spotbugs/*.html' \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml index 4fbacecde4a4..8c0bb07e1acb 100644 --- a/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml @@ -123,6 +123,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml index e37bc5c56e2e..317b2e1f2ec1 100644 --- a/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml 
index 68ebe3c28fb3..3e0208b758cc 100644 --- a/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml @@ -109,6 +109,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml index 5c0b169b0ba1..2be7607b5bc7 100644 --- a/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml index ce91551c1121..6901e56c0bbb 100644 --- a/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml index b6a0e6b999bd..6f32c3844b1a 100644 --- a/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ 
env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml index 78ab882d4774..11a95cf476c7 100644 --- a/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml @@ -117,6 +117,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml index 4bfb20a28e7c..8e22318bdcb9 100644 --- a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml +++ b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml @@ -117,4 +117,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml index 72fc945018f6..763de153b137 100644 --- a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml +++ b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml @@ -133,6 +133,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: 
'**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml index e96dc7c883bf..e121fe1e53a2 100644 --- a/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml @@ -106,6 +106,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml b/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml index 19b0d56a8051..09bf906e5a38 100644 --- a/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml +++ b/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml @@ -104,4 +104,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml index 2256d0a91cb8..ee5bea3d3ab3 100644 --- a/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml @@ -127,6 +127,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git 
a/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml index c481251bef03..0e6bd11e7f1e 100644 --- a/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml @@ -103,6 +103,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml index 3b99e30bfbac..c334edd7f32d 100644 --- a/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml @@ -107,6 +107,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml index 6d45ba82aa49..ed079c1e9dd1 100644 --- a/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml @@ -122,6 +122,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml index c2beaa3c1099..442085586a3c 100644 --- 
a/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml @@ -145,6 +145,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml b/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml index 4e19a56dde0c..cd73d402c7ea 100644 --- a/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml @@ -122,6 +122,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml index 903a7cd73526..977781de506f 100644 --- a/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml index 071cdb3bda3e..4759d48d979f 100644 --- a/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml @@ -112,6 +112,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 
'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml index 650036345274..935315463358 100644 --- a/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml @@ -112,6 +112,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml index 0ede01376ce7..f177ec85fada 100644 --- a/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml deleted file mode 100644 index 494a738abf45..000000000000 --- a/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: PreCommit Java Kinesis IO Direct - -on: - push: - tags: ['v*'] - branches: ['master', 'release-*'] - paths: - - "sdks/java/io/kinesis/**" - - "sdks/java/io/common/**" - - "sdks/java/core/src/main/**" - - "build.gradle" - - "buildSrc/**" - - "gradle/**" - - "gradle.properties" - - "gradlew" - - "gradle.bat" - - "settings.gradle.kts" - - ".github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml" - pull_request_target: - branches: ['master', 'release-*'] - paths: - - "sdks/java/io/kinesis/**" - - "sdks/java/io/common/**" - - "sdks/java/core/src/main/**" - - "build.gradle" - - "buildSrc/**" - - "gradle/**" - - "gradle.properties" - - "gradlew" - - "gradle.bat" - - "settings.gradle.kts" - - 'release/trigger_all_tests.json' - - '.github/trigger_files/beam_PreCommit_Java_Kinesis_IO_Direct.json' - issue_comment: - types: [created] - schedule: - - cron: '0 2/6 * * *' - workflow_dispatch: - -#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event -permissions: - actions: write - pull-requests: write - checks: write - contents: read - deployments: read - id-token: none - issues: write - discussions: read - packages: read - pages: read - repository-projects: read - security-events: read - statuses: read - -# This allows a subsequently queued workflow run to interrupt previous runs -concurrency: - group: '${{ 
github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' - cancel-in-progress: true - -env: - DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} - GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} - GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} - -jobs: - beam_PreCommit_Java_Kinesis_IO_Direct: - name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - strategy: - matrix: - job_name: ["beam_PreCommit_Java_Kinesis_IO_Direct"] - job_phrase: ["Run Java_Kinesis_IO_Direct PreCommit"] - timeout-minutes: 60 - if: | - github.event_name == 'push' || - github.event_name == 'pull_request_target' || - (github.event_name == 'schedule' && github.repository == 'apache/beam') || - github.event_name == 'workflow_dispatch' || - github.event.comment.body == 'Run Java_Kinesis_IO_Direct PreCommit' - runs-on: [self-hosted, ubuntu-20.04, main] - steps: - - uses: actions/checkout@v4 - - name: Setup repository - uses: ./.github/actions/setup-action - with: - comment_phrase: ${{ matrix.job_phrase }} - github_token: ${{ secrets.GITHUB_TOKEN }} - github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - - name: Setup environment - uses: ./.github/actions/setup-environment-action - - name: run Kinesis IO build script - uses: ./.github/actions/gradle-command-self-hosted-action - with: - gradle-command: :sdks:java:io:kinesis:build - arguments: | - -PdisableSpotlessCheck=true \ - -PdisableCheckStyle=true \ - - name: run Kinesis expansion service script - uses: ./.github/actions/gradle-command-self-hosted-action - with: - gradle-command: :sdks:java:io:kinesis:expansion-service:build - arguments: | - -PdisableSpotlessCheck=true \ - -PdisableCheckStyle=true \ - - name: run Kinesis IO IT script - uses: ./.github/actions/gradle-command-self-hosted-action - with: - gradle-command: 
:sdks:java:io:kinesis:integrationTest - arguments: | - -PdisableSpotlessCheck=true \ - -PdisableCheckStyle=true \ - - name: Archive JUnit Test Results - uses: actions/upload-artifact@v4 - if: ${{ !success() }} - with: - name: JUnit Test Results - path: "**/build/reports/tests/" - - name: Publish JUnit Test Results - uses: EnricoMi/publish-unit-test-result-action@v2 - if: always() - with: - commit: '${{ env.prsha || env.GITHUB_SHA }}' - comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' - - name: Archive SpotBugs Results - uses: actions/upload-artifact@v4 - if: always() - with: - name: SpotBugs Results - path: '**/build/reports/spotbugs/*.html' - - name: Publish SpotBugs Results - uses: jwgmeligmeyling/spotbugs-github-action@v1.2 - if: always() - with: - name: Publish SpotBugs - path: '**/build/reports/spotbugs/*.html' \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml index e38c9a761dee..853e52db14db 100644 --- a/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml index 11be57c05759..b3292ac5f29b 100644 --- a/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: 
'**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml index ac8800f55cdf..ed0189d8006b 100644 --- a/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml index 553300f1889c..62429a611f2a 100644 --- a/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml index 5feb0270c68c..48f165f4e59f 100644 --- a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml +++ b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml @@ -115,3 +115,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml index 0bec073fc37b..d217f0e88c39 100644 --- 
a/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml index e25b4ff6fa94..3a9d62fb64c6 100644 --- a/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml @@ -123,6 +123,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml index eb343f193395..c72b04bc108d 100644 --- a/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml index 13b9c26b4b81..cd4ddc387ffc 100644 --- a/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ 
github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml index f1e8c3699aa6..1037ab972447 100644 --- a/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml @@ -103,6 +103,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml index 4d289882353e..478dad9989b9 100644 --- a/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml @@ -107,6 +107,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml index 03577eff2860..403c26ac0ab0 100644 --- a/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml @@ -116,6 +116,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results 
uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml index 5aeaaec11dec..ca05b44875cb 100644 --- a/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml @@ -112,6 +112,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml index e6138a0c10d9..80cd5e492992 100644 --- a/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml b/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml index 18f5a6c0c86e..c6b2d7e57128 100644 --- a/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml +++ b/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml @@ -112,4 +112,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml index 
73a1a0b5cdb2..53f3c4327739 100644 --- a/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml index 4cddfa728cc1..b5336537c556 100644 --- a/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml index e08b5048b359..195e9aa1f168 100644 --- a/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Python.yml b/.github/workflows/beam_PreCommit_Python.yml index fb1c6c80873a..68c69ae953a4 100644 --- a/.github/workflows/beam_PreCommit_Python.yml +++ b/.github/workflows/beam_PreCommit_Python.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 
'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Coverage.yml b/.github/workflows/beam_PreCommit_Python_Coverage.yml index 0e295250817d..3c7c3b05d8bc 100644 --- a/.github/workflows/beam_PreCommit_Python_Coverage.yml +++ b/.github/workflows/beam_PreCommit_Python_Coverage.yml @@ -104,4 +104,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Dataframes.yml b/.github/workflows/beam_PreCommit_Python_Dataframes.yml index f045842e061d..ecbb1a30e5f7 100644 --- a/.github/workflows/beam_PreCommit_Python_Dataframes.yml +++ b/.github/workflows/beam_PreCommit_Python_Dataframes.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Examples.yml b/.github/workflows/beam_PreCommit_Python_Examples.yml index 09d46217d6d6..44329f63014d 100644 --- a/.github/workflows/beam_PreCommit_Python_Examples.yml +++ b/.github/workflows/beam_PreCommit_Python_Examples.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Integration.yml b/.github/workflows/beam_PreCommit_Python_Integration.yml index 
20aade431f6d..3a709c70f077 100644 --- a/.github/workflows/beam_PreCommit_Python_Integration.yml +++ b/.github/workflows/beam_PreCommit_Python_Integration.yml @@ -116,4 +116,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_ML.yml b/.github/workflows/beam_PreCommit_Python_ML.yml index 714eceef5f6b..3b3a2150ac28 100644 --- a/.github/workflows/beam_PreCommit_Python_ML.yml +++ b/.github/workflows/beam_PreCommit_Python_ML.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml b/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml index dbc1264fcc04..5dd12d49ccd9 100644 --- a/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml +++ b/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml @@ -125,4 +125,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Runners.yml b/.github/workflows/beam_PreCommit_Python_Runners.yml index 5db6e94be781..689d9b2c3c3f 100644 --- a/.github/workflows/beam_PreCommit_Python_Runners.yml +++ b/.github/workflows/beam_PreCommit_Python_Runners.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: 
'**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Transforms.yml b/.github/workflows/beam_PreCommit_Python_Transforms.yml index 820ca3e26df6..431b82c02fb7 100644 --- a/.github/workflows/beam_PreCommit_Python_Transforms.yml +++ b/.github/workflows/beam_PreCommit_Python_Transforms.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_SQL.yml b/.github/workflows/beam_PreCommit_SQL.yml index b4002fcc2a79..5bc8bb581955 100644 --- a/.github/workflows/beam_PreCommit_SQL.yml +++ b/.github/workflows/beam_PreCommit_SQL.yml @@ -103,6 +103,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_SQL_Java17.yml b/.github/workflows/beam_PreCommit_SQL_Java17.yml index 0e5dcc87d16f..1cfd7502389d 100644 --- a/.github/workflows/beam_PreCommit_SQL_Java17.yml +++ b/.github/workflows/beam_PreCommit_SQL_Java17.yml @@ -110,6 +110,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_SQL_Java8.yml b/.github/workflows/beam_PreCommit_SQL_Java8.yml index 23938821b2e8..6b59739dd72d 100644 --- a/.github/workflows/beam_PreCommit_SQL_Java8.yml +++ 
b/.github/workflows/beam_PreCommit_SQL_Java8.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml index b17913946a7e..b9e310a7a133 100644 --- a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml +++ b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml @@ -105,3 +105,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_Publish_Java_SDK_Distroless_Snapshots.yml b/.github/workflows/beam_Publish_Java_SDK_Distroless_Snapshots.yml new file mode 100644 index 000000000000..74a0f9a81d63 --- /dev/null +++ b/.github/workflows/beam_Publish_Java_SDK_Distroless_Snapshots.yml @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: Publish Beam Java SDK Distroless Snapshots + +on: + schedule: + - cron: '45 */8 * * *' + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + docker_registry: gcr.io + +jobs: + Java_SDK_Distroless_Snapshots: + if: | + github.event_name == 'workflow_dispatch' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 160 + name: ${{ matrix.job_name }} (${{ matrix.java_version }}) + strategy: + fail-fast: false + matrix: + job_name: ["Java_SDK_Distroless_Snapshots"] + job_phrase: ["N/A"] + java_version: + - "java17" + - "java21" + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.java_version }}) + - name: Find Beam Version + # We extract the Beam version here and tag the containers with it. Version will be in the form "2.xx.y.dev". 
+ # This is needed to run pipelines that use the default environment at HEAD, for example, when a + # pipeline uses an expansion service built from HEAD. + run: | + BEAM_VERSION_LINE=$(cat gradle.properties | grep "sdk_version") + echo "BEAM_VERSION=${BEAM_VERSION_LINE#*sdk_version=}" >> $GITHUB_ENV + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: GCloud Docker credential helper + run: | + gcloud auth configure-docker ${{ env.docker_registry }} + - name: Build and push Java distroless image + run: | + docker buildx build --push \ + -t gcr.io/apache-beam-testing/beam-sdk/beam_${{ matrix.java_version }}_sdk_distroless:${{ github.sha }} \ + -t gcr.io/apache-beam-testing/beam-sdk/beam_${{ matrix.java_version }}_sdk_distroless:${BEAM_VERSION} \ + -t gcr.io/apache-beam-testing/beam-sdk/beam_${{ matrix.java_version }}_sdk_distroless:latest \ + -f sdks/java/container/Dockerfile-distroless \ + --build-arg=BEAM_BASE=gcr.io/apache-beam-testing/beam-sdk/beam_${{ matrix.java_version }}_sdk:${BEAM_VERSION} \ + --build-arg=DISTROLESS_BASE=gcr.io/distroless/${{ matrix.java_version }}-debian12 \ + . diff --git a/.github/workflows/beam_Publish_Python_SDK_Distroless_Snapshots.yml b/.github/workflows/beam_Publish_Python_SDK_Distroless_Snapshots.yml new file mode 100644 index 000000000000..9ae37712044a --- /dev/null +++ b/.github/workflows/beam_Publish_Python_SDK_Distroless_Snapshots.yml @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Publish Beam Python SDK Distroless Snapshots + +on: + schedule: + - cron: '45 */8 * * *' + workflow_dispatch: + + #Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + docker_registry: gcr.io + +jobs: + Python_SDK_Distroless_Snapshots: + if: | + github.event_name == 'workflow_dispatch' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 160 + name: ${{ matrix.job_name }} (${{ matrix.python_version }}) + strategy: + fail-fast: false + matrix: + job_name: ["Python_SDK_Distroless_Snapshots"] + job_phrase: ["N/A"] + python_version: + - "python3.9" + - "python3.10" + - "python3.11" + - "python3.12" + steps: + - uses: actions/checkout@v4 + - name: 
Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.python_version }}) + - name: Find Beam Version + # We extract the Beam version here and tag the containers with it. Version will be in the form "2.xx.y.dev". + # This is needed to run pipelines that use the default environment at HEAD, for example, when a + # pipeline uses an expansion service built from HEAD. + run: | + BEAM_VERSION_LINE=$(cat gradle.properties | grep "sdk_version") + echo "BEAM_VERSION=${BEAM_VERSION_LINE#*sdk_version=}" >> $GITHUB_ENV + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: GCloud Docker credential helper + run: | + gcloud auth configure-docker ${{ env.docker_registry }} + # TODO(https://github.com/apache/beam/issues/32914): create after merging into main branch + # - name: Build and push Python distroless image + + + diff --git a/.github/workflows/beam_Python_CostBenchmarks_Dataflow.yml b/.github/workflows/beam_Python_CostBenchmarks_Dataflow.yml new file mode 100644 index 000000000000..209325c429a1 --- /dev/null +++ b/.github/workflows/beam_Python_CostBenchmarks_Dataflow.yml @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +name: Python Cost Benchmarks Dataflow + +on: + schedule: + - cron: '30 18 * * 6' # Run at 6:30 pm UTC on Saturdays + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + INFLUXDB_USER: ${{ secrets.INFLUXDB_USER }} + INFLUXDB_USER_PASSWORD: ${{ secrets.INFLUXDB_USER_PASSWORD }} + +jobs: + beam_Python_Cost_Benchmarks_Dataflow: + if: | + github.event_name == 'workflow_dispatch' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 900 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_Python_CostBenchmark_Dataflow"] + job_phrase: ["Run Python Dataflow Cost Benchmarks"] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup Python 
environment + uses: ./.github/actions/setup-environment-action + with: + python-version: '3.10' + - name: Prepare test arguments + uses: ./.github/actions/test-arguments-action + with: + test-type: load + test-language: python + argument-file-paths: | + ${{ github.workspace }}/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt + ${{ github.workspace }}/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt + # The env variables are created and populated in the test-arguments-action as "_test_arguments_" + - name: get current time + run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV + - name: Run wordcount on Dataflow + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 30 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.wordcount.wordcount \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + '-PloadTest.args=${{ env.beam_Python_Cost_Benchmarks_Dataflow_test_arguments_1 }} --job_name=benchmark-tests-wordcount-python-${{env.NOW_UTC}} --output_file=gs://temp-storage-for-end-to-end-tests/wordcount/result_wordcount-${{env.NOW_UTC}}.txt' \ + - name: Run Tensorflow MNIST Image Classification on Dataflow + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 30 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.inference.tensorflow_mnist_classification_cost_benchmark \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + -PloadTest.requirementsTxtFile=apache_beam/ml/inference/tensorflow_tests_requirements.txt \ + '-PloadTest.args=${{ env.beam_Python_Cost_Benchmarks_Dataflow_test_arguments_2 }} --job_name=benchmark-tests-tf-mnist-classification-python-${{env.NOW_UTC}} --input_file=gs://apache-beam-ml/testing/inputs/it_mnist_data.csv 
--output_file=gs://temp-storage-for-end-to-end-tests/inference/result_tf_mnist-${{env.NOW_UTC}}.txt --model=gs://apache-beam-ml/models/tensorflow/mnist/' \ \ No newline at end of file diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 828a6328c0cd..20706e77d0cd 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -219,6 +219,7 @@ jobs: runs-on: ${{ matrix.os_python.runner }} timeout-minutes: 480 strategy: + fail-fast: false matrix: os_python: [ {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-full }}", arch: "auto" }, @@ -226,7 +227,7 @@ jobs: # TODO(https://github.com/apache/beam/issues/31114) {"os": "macos-13", "runner": "macos-13", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "auto" }, {"os": "windows-latest", "runner": "windows-latest", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "auto" }, - {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "aarch64" } + {"os": "ubuntu-20.04", "runner": "ubuntu-latest", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "aarch64" } ] # Keep in sync (remove asterisks) with PY_VERSIONS_FULL env var above - if changed, change that as well. py_version: ["cp39-", "cp310-", "cp311-", "cp312-"] diff --git a/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt b/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt new file mode 100644 index 000000000000..01f4460b8c7e --- /dev/null +++ b/.github/workflows/cost-benchmarks-pipeline-options/python_tf_mnist_classification.txt @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--region=us-central1 +--machine_type=n1-standard-2 +--num_workers=1 +--disk_size_gb=50 +--autoscaling_algorithm=NONE +--input_options={} +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--requirements_file=apache_beam/ml/inference/tensorflow_tests_requirements.txt +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=tf_mnist_classification +--runner=DataflowRunner \ No newline at end of file diff --git a/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt b/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt new file mode 100644 index 000000000000..424936ddad97 --- /dev/null +++ b/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--region=us-central1 +--machine_type=n1-standard-2 +--num_workers=1 +--disk_size_gb=50 +--autoscaling_algorithm=NONE +--input_options={} +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=python_wordcount +--runner=DataflowRunner \ No newline at end of file diff --git a/.github/workflows/java_tests.yml b/.github/workflows/java_tests.yml index 1d6441b24681..bdc78b88cb97 100644 --- a/.github/workflows/java_tests.yml +++ b/.github/workflows/java_tests.yml @@ -20,11 +20,7 @@ name: Java Tests on: workflow_dispatch: - inputs: - runDataflow: - description: 'Type "true" if you want to run Dataflow tests (GCP variables must be configured, check CI.md)' - default: 'false' - required: false + schedule: - cron: '10 2 * * *' push: @@ -33,8 +29,7 @@ on: pull_request: branches: ['master', 'release-*'] tags: ['v*'] - paths: ['sdks/java/**', 'model/**', 'runners/**', 'examples/java/**', - 'examples/kotlin/**', 'release/**', 'buildSrc/**'] + paths: ['sdks/java/**', 'model/**', 'runners/**', 'examples/java/**', 'examples/kotlin/**', 'release/**', 'buildSrc/**'] # This allows a subsequently queued workflow run to interrupt previous runs concurrency: group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login}}' @@ -44,26 +39,6 @@ env: 
GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - check_gcp_variables: - timeout-minutes: 5 - name: "Check GCP variables set" - runs-on: [self-hosted, ubuntu-20.04, main] - outputs: - gcp-variables-set: ${{ steps.check_gcp_variables.outputs.gcp-variables-set }} - steps: - - name: Check out code - uses: actions/checkout@v4 - - name: "Check are GCP variables set" - run: "./scripts/ci/ci_check_are_gcp_variables_set.sh" - id: check_gcp_variables - env: - GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} - GCP_SA_EMAIL: ${{ secrets.GCP_SA_EMAIL }} - GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} - GCP_TESTING_BUCKET: ${{ secrets.GCP_TESTING_BUCKET }} - GCP_REGION: "not-needed-here" - GCP_PYTHON_WHEELS_BUCKET: "not-needed-here" - java_unit_tests: name: 'Java Unit Tests' runs-on: ${{ matrix.os }} @@ -152,46 +127,3 @@ jobs: with: name: java_wordcount_direct_runner-${{matrix.os}} path: examples/java/build/reports/tests/integrationTest - - java_wordcount_dataflow: - name: 'Java Wordcount Dataflow' - needs: - - check_gcp_variables - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [[self-hosted, ubuntu-20.04, main], windows-latest] - # TODO(https://github.com/apache/beam/issues/31848) run on Dataflow after fixes credential on macOS/win GHA runner - if: | - needs.check_gcp_variables.outputs.gcp-variables-set == 'true' && - (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') - steps: - - name: Check out code - uses: actions/checkout@v4 - with: - persist-credentials: false - submodules: recursive - - name: Setup environment - uses: ./.github/actions/setup-environment-action - with: - java-version: 11 - go-version: default - - name: Authenticate on GCP - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - - name: Run WordCount - uses: 
./.github/actions/gradle-command-self-hosted-action - with: - gradle-command: integrationTest - arguments: -p examples/ --tests org.apache.beam.examples.WordCountIT - -DintegrationTestPipelineOptions=[\"--runner=DataflowRunner\",\"--project=${{ secrets.GCP_PROJECT_ID }}\",\"--tempRoot=gs://${{ secrets.GCP_TESTING_BUCKET }}/tmp/\"] - -DintegrationTestRunner=dataflow - - name: Upload test logs - uses: actions/upload-artifact@v4 - if: always() - with: - name: java_wordcount_dataflow-${{matrix.os}} - path: examples/java/build/reports/tests/integrationTest \ No newline at end of file diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index 3000d1871be3..2c3b39a33c1d 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -30,10 +30,6 @@ on: tags: 'v*' paths: ['sdks/python/**', 'model/**'] workflow_dispatch: - inputs: - runDataflow: - description: 'Type "true" if you want to run Dataflow tests (GCP variables must be configured, check CI.md)' - default: false # This allows a subsequently queued workflow run to interrupt previous runs concurrency: @@ -57,7 +53,6 @@ jobs: GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} GCP_REGION: ${{ secrets.GCP_REGION }} GCP_SA_EMAIL: ${{ secrets.GCP_SA_EMAIL }} - GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} GCP_TESTING_BUCKET: ${{ secrets.GCP_TESTING_BUCKET }} GCP_PYTHON_WHEELS_BUCKET: "not-needed-here" @@ -65,8 +60,8 @@ jobs: name: 'Build python source distribution' if: | needs.check_gcp_variables.outputs.gcp-variables-set == 'true' && ( - (github.event_name == 'push' || github.event_name == 'schedule') || - (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') + ((github.event_name == 'push' || github.event_name == 'schedule') || + github.event_name == 'workflow_dispatch') ) needs: - check_gcp_variables @@ -153,50 +148,3 @@ jobs: working-directory: ./sdks/python shell: bash run: python -m apache_beam.examples.wordcount --input MANIFEST.in 
--output counts - - python_wordcount_dataflow: - name: 'Python Wordcount Dataflow' - # TODO(https://github.com/apache/beam/issues/31848) run on Dataflow after fixes credential on macOS/win GHA runner - if: (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') - needs: - - build_python_sdk_source - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [[self-hosted, ubuntu-20.04, main], macos-latest, windows-latest] - python: ["3.9", "3.10", "3.11", "3.12"] - steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Setup environment - uses: ./.github/actions/setup-environment-action - with: - python-version: ${{ matrix.python }} - go-version: default - - name: Download source from artifacts - uses: actions/download-artifact@v4.1.8 - with: - name: python_sdk_source - path: apache-beam-source - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - - name: Install requirements - working-directory: ./sdks/python - run: pip install setuptools --upgrade && pip install -e ".[gcp]" - - name: Run WordCount - working-directory: ./sdks/python - shell: bash - run: | - python -m apache_beam.examples.wordcount \ - --input gs://dataflow-samples/shakespeare/kinglear.txt \ - --output gs://${{ secrets.GCP_TESTING_BUCKET }}/python_wordcount_dataflow/counts \ - --runner DataflowRunner \ - --project ${{ secrets.GCP_PROJECT_ID }} \ - --region ${{ secrets.GCP_REGION }} \ - --temp_location gs://${{ secrets.GCP_TESTING_BUCKET }}/tmp/python_wordcount_dataflow/ \ - --sdk_location ../../apache-beam-source/apache-beam-source.tar.gz diff --git a/.github/workflows/republish_released_docker_containers.yml b/.github/workflows/republish_released_docker_containers.yml index ed6e74ecf13d..5ab38cff10fc 100644 --- a/.github/workflows/republish_released_docker_containers.yml +++ 
b/.github/workflows/republish_released_docker_containers.yml @@ -24,18 +24,16 @@ on: inputs: RELEASE: description: Beam version of current release (e.g. 2.XX.0) - required: true - default: '' + required: false RC: description: Integer RC version for the release (e.g. 3 for RC3) - required: true - default: '' + required: false schedule: - cron: "0 6 * * 1" env: docker_registry: gcr.io - release: ${{ github.event.inputs.RELEASE || "2.61.0" }} - rc: ${{ github.event.inputs.RC || "3" }} + release: "${{ github.event.inputs.RELEASE || '2.61.0' }}" + rc: "${{ github.event.inputs.RC || '3' }}" jobs: @@ -69,5 +67,13 @@ jobs: run: | gcloud auth configure-docker ${{ env.docker_registry }} - name: Push docker images - run: ./gradlew :pushAllDockerImages -PisRelease -Pdocker-pull-licenses -Pprune-images -Pdocker-repository-root=gcr.io/apache-beam-testing/updated_released_container_images -Pdocker-tag=${{ env.release }}rc${{ env.rc }} --no-daemon --no-parallel + run: | + ./gradlew :pushAllDockerImages \ + -PisRelease \ + -Pdocker-pull-licenses \ + -Pprune-images \ + -Pdocker-repository-root=gcr.io/apache-beam-testing/updated_released_container_images \ + -Pdocker-tag-list=${{ env.release }},${{ github.sha }},$(date +'%Y-%m-%d') \ + --no-daemon \ + --no-parallel diff --git a/.github/workflows/typescript_tests.yml b/.github/workflows/typescript_tests.yml index a4e4c2926f84..a25f4d2de42d 100644 --- a/.github/workflows/typescript_tests.yml +++ b/.github/workflows/typescript_tests.yml @@ -21,6 +21,13 @@ name: TypeScript Tests on: workflow_dispatch: + inputs: + runXlang: + description: 'Type "true" if you want to run xlang tests' + default: false + runDataflow: + description: 'Type "true" if you want to run Dataflow tests' + default: false schedule: - cron: '10 2 * * *' push: @@ -68,6 +75,8 @@ jobs: # if: ${{ matrix.os != 'ubuntu-latest' }} typescript_xlang_tests: name: 'TypeScript xlang Tests' + # TODO(https://github.com/apache/beam/issues/33346): remove manual trigger after fixing 
referenced issue. + if: (github.event_name == 'workflow_dispatch' && github.event.inputs.runXlang == 'true') runs-on: [self-hosted, ubuntu-20.04] timeout-minutes: 15 strategy: @@ -115,16 +124,16 @@ jobs: GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} GCP_REGION: ${{ secrets.GCP_REGION }} GCP_SA_EMAIL: ${{ secrets.GCP_SA_EMAIL }} - GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} GCP_TESTING_BUCKET: ${{ secrets.GCP_TESTING_BUCKET }} GCP_PYTHON_WHEELS_BUCKET: "not-needed-here" typescript_dataflow_tests: name: 'TypeScript Dataflow Tests' + # TODO(https://github.com/apache/beam/issues/33346): remove manual trigger after fixing referenced issue. + if: (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') runs-on: ubuntu-latest needs: - check_gcp_variables - if: needs.check_gcp_variables.outputs.gcp-variables-set == 'true' strategy: fail-fast: false steps: @@ -146,11 +155,6 @@ jobs: run: | pip install 'pandas>=1.0,<1.5' pip install -e ".[gcp]" - - name: Authenticate on GCP - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - run: npm ci working-directory: ./sdks/typescript - run: npm run build diff --git a/.gitignore b/.gitignore index 19778831888d..2bad81975ba0 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,8 @@ sdks/**/vendor/**/* runners/**/vendor/**/* **/.gradletasknamecache **/generated/* +/go.mod +/go.sum # Ignore sources generated into the main tree **/src/main/generated/** diff --git a/CHANGES.md b/CHANGES.md index 1b943a99f8a0..d5cbb76fb3d5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -53,7 +53,7 @@ * ([#X](https://github.com/apache/beam/issues/X)). --> -# [2.62.0] - Unreleased +# [2.63.0] - Unreleased ## Highlights @@ -62,7 +62,6 @@ ## I/Os -* gcs-connector config options can be set via GcsOptions (Java) ([#32769](https://github.com/apache/beam/pull/32769)). * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). 
## New Features / Improvements @@ -71,8 +70,7 @@ ## Breaking Changes -* Upgraded ZetaSQL to 2024.11.1 ([#32902](https://github.com/apache/beam/pull/32902)). Java11+ is now needed if Beam's ZetaSQL component is used. -* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). +* AWS V1 I/Os have been removed (Java). As part of this, x-lang Python Kinesis I/O has been updated to consume the V2 IO and it also no longer supports setting producer_properties ([#33430](https://github.com/apache/beam/issues/33430)). ## Deprecations @@ -81,16 +79,44 @@ ## Bugfixes * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). -* Fixed EventTimeTimer ordering in Prism. ([#32222](https://github.com/apache/beam/issues/32222)). ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). -* Fixed (CVE-2024-47561)[https://www.cve.org/CVERecord?id=CVE-2024-47561] (Java) by upgrading Avro version to 1.11.4 ## Known Issues * ([#X](https://github.com/apache/beam/issues/X)). +# [2.62.0] - Unreleased + +## I/Os + +* gcs-connector config options can be set via GcsOptions (Java) ([#32769](https://github.com/apache/beam/pull/32769)). +* [Managed Iceberg] Support partitioning by time (year, month, day, hour) for types `date`, `time`, `timestamp`, and `timestamp(tz)` ([#32939](https://github.com/apache/beam/pull/32939)) +* Upgraded the default version of Hadoop dependencies to 3.4.1. Hadoop 2.10.2 is still supported (Java) ([#33011](https://github.com/apache/beam/issues/33011)). +* [BigQueryIO] Create managed BigLake tables dynamically ([#33125](https://github.com/apache/beam/pull/33125)) + +## New Features / Improvements + +* Added support for stateful processing in Spark Runner for streaming pipelines. Timer functionality is not yet supported and will be implemented in a future release ([#33237](https://github.com/apache/beam/issues/33237)). 
+* The datetime module is now available for use in jinja templatization for yaml. +* Improved batch performance of SparkRunner's GroupByKey ([#20943](https://github.com/apache/beam/pull/20943)). +* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)). + * This enables initial Java GroupIntoBatches support. +* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)). + +## Breaking Changes + +* Upgraded ZetaSQL to 2024.11.1 ([#32902](https://github.com/apache/beam/pull/32902)). Java11+ is now needed if Beam's ZetaSQL component is used. + +## Bugfixes + +* Fixed EventTimeTimer ordering in Prism. ([#32222](https://github.com/apache/beam/issues/32222)). + +## Security Fixes + +* Fixed (CVE-2024-47561)[https://www.cve.org/CVERecord?id=CVE-2024-47561] (Java) by upgrading Avro version to 1.11.4 + # [2.61.0] - 2024-11-25 ## Highlights diff --git a/build.gradle.kts b/build.gradle.kts index d96e77a4c78c..0adb29058479 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -647,6 +647,22 @@ tasks.register("checkSetup") { dependsOn(":examples:java:wordCount") } +// if not disabled make spotlessApply dependency of compileJava and compileTestJava +val disableSpotlessCheck: String by project +val isSpotlessDisabled = project.hasProperty("disableSpotlessCheck") && + disableSpotlessCheck == "true" +if (!isSpotlessDisabled) { + subprojects { + afterEvaluate { + tasks.findByName("spotlessApply")?.let { + listOf("compileJava", "compileTestJava").forEach { + t -> tasks.findByName(t)?.let { f -> f.dependsOn("spotlessApply") } + } + } + } + } +} + // Generates external transform config project.tasks.register("generateExternalTransformsConfig") { dependsOn(":sdks:python:generateExternalTransformsConfig") diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 84c7c3ecfd4a..7b791ef9aa8e 100644 --- 
a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -594,8 +594,7 @@ class BeamModulePlugin implements Plugin { def activemq_version = "5.14.5" def autovalue_version = "1.9" def autoservice_version = "1.0.1" - def aws_java_sdk_version = "1.12.135" - def aws_java_sdk2_version = "2.20.47" + def aws_java_sdk2_version = "2.20.162" def cassandra_driver_version = "3.10.2" def cdap_version = "6.5.1" def checkerframework_version = "3.42.0" @@ -603,18 +602,18 @@ class BeamModulePlugin implements Plugin { def dbcp2_version = "2.9.0" def errorprone_version = "2.10.0" // [bomupgrader] determined by: com.google.api:gax, consistent with: google_cloud_platform_libraries_bom - def gax_version = "2.55.0" + def gax_version = "2.57.0" def google_ads_version = "33.0.0" def google_clients_version = "2.0.0" - def google_cloud_bigdataoss_version = "2.2.16" + def google_cloud_bigdataoss_version = "2.2.26" // [bomupgrader] determined by: com.google.cloud:google-cloud-spanner, consistent with: google_cloud_platform_libraries_bom - def google_cloud_spanner_version = "6.79.0" + def google_cloud_spanner_version = "6.80.1" def google_code_gson_version = "2.10.1" def google_oauth_clients_version = "1.34.1" // [bomupgrader] determined by: io.grpc:grpc-netty, consistent with: google_cloud_platform_libraries_bom def grpc_version = "1.67.1" def guava_version = "33.1.0-jre" - def hadoop_version = "2.10.2" + def hadoop_version = "3.4.1" def hamcrest_version = "2.1" def influxdb_version = "2.19" def httpclient_version = "4.5.13" @@ -631,7 +630,7 @@ class BeamModulePlugin implements Plugin { def postgres_version = "42.2.16" def powermock_version = "2.0.9" // [bomupgrader] determined by: com.google.protobuf:protobuf-java, consistent with: google_cloud_platform_libraries_bom - def protobuf_version = "3.25.5" + def protobuf_version = "4.28.3" def qpid_jms_client_version = "0.61.0" def quickcheck_version = 
"1.0" def sbe_tool_version = "1.25.1" @@ -671,14 +670,6 @@ class BeamModulePlugin implements Plugin { auto_value_annotations : "com.google.auto.value:auto-value-annotations:$autovalue_version", avro : "org.apache.avro:avro:1.11.4", avro_tests : "org.apache.avro:avro:1.11.3:tests", - aws_java_sdk_cloudwatch : "com.amazonaws:aws-java-sdk-cloudwatch:$aws_java_sdk_version", - aws_java_sdk_core : "com.amazonaws:aws-java-sdk-core:$aws_java_sdk_version", - aws_java_sdk_dynamodb : "com.amazonaws:aws-java-sdk-dynamodb:$aws_java_sdk_version", - aws_java_sdk_kinesis : "com.amazonaws:aws-java-sdk-kinesis:$aws_java_sdk_version", - aws_java_sdk_s3 : "com.amazonaws:aws-java-sdk-s3:$aws_java_sdk_version", - aws_java_sdk_sns : "com.amazonaws:aws-java-sdk-sns:$aws_java_sdk_version", - aws_java_sdk_sqs : "com.amazonaws:aws-java-sdk-sqs:$aws_java_sdk_version", - aws_java_sdk_sts : "com.amazonaws:aws-java-sdk-sts:$aws_java_sdk_version", aws_java_sdk2_apache_client : "software.amazon.awssdk:apache-client:$aws_java_sdk2_version", aws_java_sdk2_netty_client : "software.amazon.awssdk:netty-nio-client:$aws_java_sdk2_version", aws_java_sdk2_auth : "software.amazon.awssdk:auth:$aws_java_sdk2_version", @@ -737,12 +728,12 @@ class BeamModulePlugin implements Plugin { google_api_client_gson : "com.google.api-client:google-api-client-gson:$google_clients_version", google_api_client_java6 : "com.google.api-client:google-api-client-java6:$google_clients_version", google_api_common : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version - google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20240919-2.0.0", // [bomupgrader] sets version + google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20241013-2.0.0", // [bomupgrader] sets version google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240310-2.0.0", // [bomupgrader] sets version 
google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20240817-$google_clients_version", google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version", google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version", - google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20240924-2.0.0", // [bomupgrader] sets version + google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20241008-2.0.0", // [bomupgrader] sets version google_auth_library_credentials : "com.google.auth:google-auth-library-credentials", // google_cloud_platform_libraries_bom sets version google_auth_library_oauth2_http : "com.google.auth:google-auth-library-oauth2-http", // google_cloud_platform_libraries_bom sets version google_cloud_bigquery : "com.google.cloud:google-cloud-bigquery", // google_cloud_platform_libraries_bom sets version @@ -754,13 +745,13 @@ class BeamModulePlugin implements Plugin { google_cloud_core_grpc : "com.google.cloud:google-cloud-core-grpc", // google_cloud_platform_libraries_bom sets version google_cloud_datacatalog_v1beta1 : "com.google.cloud:google-cloud-datacatalog", // google_cloud_platform_libraries_bom sets version google_cloud_dataflow_java_proto_library_all: "com.google.cloud.dataflow:google-cloud-dataflow-java-proto-library-all:0.5.160304", - google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.23.0", // [bomupgrader] sets version + google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.24.1", // [bomupgrader] sets version google_cloud_firestore : "com.google.cloud:google-cloud-firestore", // google_cloud_platform_libraries_bom sets version google_cloud_pubsub : "com.google.cloud:google-cloud-pubsub", // google_cloud_platform_libraries_bom sets version google_cloud_pubsublite : 
"com.google.cloud:google-cloud-pubsublite", // google_cloud_platform_libraries_bom sets version // [bomupgrader] the BOM version is set by scripts/tools/bomupgrader.py. If update manually, also update // libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml - google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.49.0", + google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.50.0", google_cloud_secret_manager : "com.google.cloud:google-cloud-secretmanager", // google_cloud_platform_libraries_bom sets version google_cloud_spanner : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version google_cloud_spanner_test : "com.google.cloud:google-cloud-spanner:$google_cloud_spanner_version:tests", @@ -795,6 +786,7 @@ class BeamModulePlugin implements Plugin { grpc_xds : "io.grpc:grpc-xds", // google_cloud_platform_libraries_bom sets version guava : "com.google.guava:guava:$guava_version", guava_testlib : "com.google.guava:guava-testlib:$guava_version", + hadoop_auth : "org.apache.hadoop:hadoop-auth:$hadoop_version", hadoop_client : "org.apache.hadoop:hadoop-client:$hadoop_version", hadoop_common : "org.apache.hadoop:hadoop-common:$hadoop_version", hadoop_mapreduce_client_core : "org.apache.hadoop:hadoop-mapreduce-client-core:$hadoop_version", @@ -2199,7 +2191,7 @@ class BeamModulePlugin implements Plugin { /* include dependencies required by AWS S3 */ if (filesystem?.equalsIgnoreCase('s3')) { - testRuntimeOnly it.project(path: ":sdks:java:io:amazon-web-services", configuration: "testRuntimeMigration") + testRuntimeOnly it.project(path: ":sdks:java:io:amazon-web-services2", configuration: "testRuntimeMigration") } } project.task('packageIntegrationTests', type: Jar) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/Repositories.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/Repositories.groovy index 52cbbd15c35b..58ec64a0add3 100644 --- 
a/buildSrc/src/main/groovy/org/apache/beam/gradle/Repositories.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/Repositories.groovy @@ -39,20 +39,25 @@ class Repositories { mavenCentral() mavenLocal() + // For Confluent Kafka dependencies + maven { + url "https://packages.confluent.io/maven/" + content { includeGroup "io.confluent" } + } + // Release staging repository maven { url "https://oss.sonatype.org/content/repositories/staging/" } // Apache nightly snapshots - maven { url "https://repository.apache.org/snapshots" } + maven { + url "https://repository.apache.org/snapshots" + mavenContent { + snapshotsOnly() + } + } // Apache release snapshots maven { url "https://repository.apache.org/content/repositories/releases" } - - // For Confluent Kafka dependencies - maven { - url "https://packages.confluent.io/maven/" - content { includeGroup "io.confluent" } - } } // Apply a plugin which provides the 'updateOfflineRepository' task that creates an offline diff --git a/contributor-docs/code-change-guide.md b/contributor-docs/code-change-guide.md index b4300103454c..5d344aa89a44 100644 --- a/contributor-docs/code-change-guide.md +++ b/contributor-docs/code-change-guide.md @@ -375,6 +375,9 @@ Follow these steps for Maven projects. Maven-Snapshot maven snapshot repository https://repository.apache.org/content/groups/snapshots/ + + false + ``` diff --git a/contributor-docs/discussion-docs/2024.md b/contributor-docs/discussion-docs/2024.md index 124fe8ef9bb7..a5e1202997ed 100644 --- a/contributor-docs/discussion-docs/2024.md +++ b/contributor-docs/discussion-docs/2024.md @@ -44,3 +44,10 @@ limitations under the License. 
| 27 | Ahmed Abualsaud | [Python Multi-language with SchemaTransforms](https://docs.google.com/document/d/1_embA3pGwoYG7sbHaYzAkg3hNxjTughhFCY8ThcoK_Q) | 2024-08-26 19:53:10 | | 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | | 29 | Jeff Kinard | [Beam YA(ML)^2](https://docs.google.com/document/d/1z9lNlSBfqDVdOP1frJNv_NJoMR1F1VBI29wn788x6IE/) | 2024-09-11 15:01:55 | +| 30 | Ahmed Abualsaud | [Beam Dynamic Destinations Naming](https://docs.google.com/document/d/1IIn4cjF9eYASnjSmVmmAt6ymFnpBxHgBKVPgpnQ12G4) | 2024-09-18 15:01:55 | +| 31 | Claude van der Merwe | [RAG with Apache Beam ](https://docs.google.com/document/d/1j-kujrxHw4R3-oT4pVAwEIqejoCXhFedqZnBUF8AKBQ) | 2024-11-08 16:37:00 | +| 32 | Shunping Huang | [Anomaly Detection with Beam](https://docs.google.com/document/d/1tE8lz9U_vjlNn2H7t-GRrs3vfhQ5UuCgWiHXCRHRPns) | 2024-12-13 10:37:00 | +| 33 | Radek Stankiewicz | [Kerberized Worker Harness](https://docs.google.com/document/d/1T3Py6VZhP-FNQMjiURj38ddZyhWQRa_vDEUEc4f1P5A) | 2024-12-16 07:27:00 | + + + diff --git a/examples/multi-language/python/wordcount_external.py b/examples/multi-language/python/wordcount_external.py index 580c0269d361..7298d81c1b44 100644 --- a/examples/multi-language/python/wordcount_external.py +++ b/examples/multi-language/python/wordcount_external.py @@ -18,8 +18,8 @@ import logging import apache_beam as beam -from apache_beam.io import ReadFromText from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms.external import BeamJarExpansionService from apache_beam.transforms.external_transform_provider import ExternalTransformProvider from apache_beam.typehints.row_type import RowTypeConstraint """A Python multi-language pipeline that counts words using multiple Java SchemaTransforms. 
@@ -60,39 +60,35 @@ --expansion_service_port """ -# Original Java transform is in ExtractWordsProvider.java EXTRACT_IDENTIFIER = "beam:schematransform:org.apache.beam:extract_words:v1" -# Original Java transform is in JavaCountProvider.java COUNT_IDENTIFIER = "beam:schematransform:org.apache.beam:count:v1" -# Original Java transform is in WriteWordsProvider.java WRITE_IDENTIFIER = "beam:schematransform:org.apache.beam:write_words:v1" def run(input_path, output_path, expansion_service_port, pipeline_args): pipeline_options = PipelineOptions(pipeline_args) - # Discover and get external transforms from this expansion service - provider = ExternalTransformProvider("localhost:" + expansion_service_port) - # Get transforms with identifiers, then use them as you would a regular - # native PTransform - Extract = provider.get_urn(EXTRACT_IDENTIFIER) - Count = provider.get_urn(COUNT_IDENTIFIER) - Write = provider.get_urn(WRITE_IDENTIFIER) - with beam.Pipeline(options=pipeline_options) as p: - lines = p | 'Read' >> ReadFromText(input_path) - - words = (lines - | 'Prepare Rows' >> beam.Map(lambda line: beam.Row(line=line)) - | 'Extract Words' >> Extract()) - word_counts = words | 'Count Words' >> Count() - formatted_words = ( - word_counts - | 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % ( - row.word, row.count))).with_output_types( - RowTypeConstraint.from_fields([('line', str)]))) - - formatted_words | 'Write' >> Write(file_path_prefix=output_path) + expansion_service = BeamJarExpansionService( + "examples:multi-language:shadowJar") + if expansion_service_port: + expansion_service = "localhost:" + expansion_service_port + + provider = ExternalTransformProvider(expansion_service) + # Retrieve portable transforms + Extract = provider.get_urn(EXTRACT_IDENTIFIER) + Count = provider.get_urn(COUNT_IDENTIFIER) + Write = provider.get_urn(WRITE_IDENTIFIER) + + _ = (p + | 'Read' >> beam.io.ReadFromText(input_path) + | 'Prepare Rows' >> beam.Map(lambda line: 
beam.Row(line=line)) + | 'Extract Words' >> Extract(drop=["king", "palace"]) + | 'Count Words' >> Count() + | 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % ( + row.word, row.count))).with_output_types( + RowTypeConstraint.from_fields([('line', str)])) + | 'Write' >> Write(file_path_prefix=output_path)) if __name__ == '__main__': @@ -110,8 +106,10 @@ def run(input_path, output_path, expansion_service_port, pipeline_args): help='Output file') parser.add_argument('--expansion_service_port', dest='expansion_service_port', - required=True, - help='Expansion service port') + required=False, + help='Expansion service port. If left empty, the ' + 'existing multi-language examples service will ' + 'be used by default.') known_args, pipeline_args = parser.parse_known_args() run(known_args.input, known_args.output, known_args.expansion_service_port, diff --git a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/ExtractWordsProvider.java b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/ExtractWordsProvider.java index 724dbce276fb..b7224ecec6b4 100644 --- a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/ExtractWordsProvider.java +++ b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/ExtractWordsProvider.java @@ -21,9 +21,12 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.List; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import 
org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -36,7 +39,6 @@ /** Splits a line into separate words and returns each word. */ @AutoService(SchemaTransformProvider.class) public class ExtractWordsProvider extends TypedSchemaTransformProvider { - public static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build(); @Override public String identifier() { @@ -45,32 +47,60 @@ public String identifier() { @Override protected SchemaTransform from(Configuration configuration) { - return new SchemaTransform() { - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - return PCollectionRowTuple.of( - "output", - input.get("input").apply(ParDo.of(new ExtractWordsFn())).setRowSchema(OUTPUT_SCHEMA)); - } - }; + return new ExtractWordsTransform(configuration); } - static class ExtractWordsFn extends DoFn { - @ProcessElement - public void processElement(@Element Row element, OutputReceiver receiver) { - // Split the line into words. - String line = Preconditions.checkStateNotNull(element.getString("line")); - String[] words = line.split("[^\\p{L}]+", -1); + static class ExtractWordsTransform extends SchemaTransform { + private static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build(); + private final List drop; - for (String word : words) { - if (!word.isEmpty()) { - receiver.output(Row.withSchema(OUTPUT_SCHEMA).withFieldValue("word", word).build()); - } - } + ExtractWordsTransform(Configuration configuration) { + this.drop = configuration.getDrop(); + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + return PCollectionRowTuple.of( + "output", + input + .getSinglePCollection() + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void process(@Element Row element, OutputReceiver receiver) { + // Split the line into words. 
+ String line = Preconditions.checkStateNotNull(element.getString("line")); + String[] words = line.split("[^\\p{L}]+", -1); + Arrays.stream(words) + .filter(w -> !drop.contains(w)) + .forEach( + word -> + receiver.output( + Row.withSchema(OUTPUT_SCHEMA) + .withFieldValue("word", word) + .build())); + } + })) + .setRowSchema(OUTPUT_SCHEMA)); } } @DefaultSchema(AutoValueSchema.class) @AutoValue - protected abstract static class Configuration {} + public abstract static class Configuration { + public static Builder builder() { + return new AutoValue_ExtractWordsProvider_Configuration.Builder(); + } + + @SchemaFieldDescription("List of words to drop.") + public abstract List getDrop(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setDrop(List foo); + + public abstract Configuration build(); + } + } } diff --git a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/JavaCountProvider.java b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/JavaCountProvider.java index cabea594ae18..90d02d92c3cb 100644 --- a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/JavaCountProvider.java +++ b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/JavaCountProvider.java @@ -44,35 +44,37 @@ public String identifier() { @Override protected SchemaTransform from(Configuration configuration) { - return new SchemaTransform() { - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - Schema outputSchema = - Schema.builder().addStringField("word").addInt64Field("count").build(); + return new JavaCountTransform(); + } + + static class JavaCountTransform extends SchemaTransform { + static final Schema OUTPUT_SCHEMA = + Schema.builder().addStringField("word").addInt64Field("count").build(); - PCollection wordCounts = - input - .get("input") - 
.apply(Count.perElement()) - .apply( - MapElements.into(TypeDescriptors.rows()) - .via( - kv -> - Row.withSchema(outputSchema) - .withFieldValue( - "word", - Preconditions.checkStateNotNull( - kv.getKey().getString("word"))) - .withFieldValue("count", kv.getValue()) - .build())) - .setRowSchema(outputSchema); + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection wordCounts = + input + .get("input") + .apply(Count.perElement()) + .apply( + MapElements.into(TypeDescriptors.rows()) + .via( + kv -> + Row.withSchema(OUTPUT_SCHEMA) + .withFieldValue( + "word", + Preconditions.checkStateNotNull( + kv.getKey().getString("word"))) + .withFieldValue("count", kv.getValue()) + .build())) + .setRowSchema(OUTPUT_SCHEMA); - return PCollectionRowTuple.of("output", wordCounts); - } - }; + return PCollectionRowTuple.of("output", wordCounts); + } } @DefaultSchema(AutoValueSchema.class) @AutoValue - protected abstract static class Configuration {} + public abstract static class Configuration {} } diff --git a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/WriteWordsProvider.java b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/WriteWordsProvider.java index 0b2017c5587a..faf9590a7f16 100644 --- a/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/WriteWordsProvider.java +++ b/examples/multi-language/src/main/java/org/apache/beam/examples/multilanguage/schematransforms/WriteWordsProvider.java @@ -42,24 +42,32 @@ public String identifier() { @Override protected SchemaTransform from(Configuration configuration) { - return new SchemaTransform() { - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - input - .get("input") - .apply( - MapElements.into(TypeDescriptors.strings()) - .via(row -> Preconditions.checkStateNotNull(row.getString("line")))) - 
.apply(TextIO.write().to(configuration.getFilePathPrefix())); + return new WriteWordsTransform(configuration); + } + + static class WriteWordsTransform extends SchemaTransform { + private final String filePathPrefix; + + WriteWordsTransform(Configuration configuration) { + this.filePathPrefix = configuration.getFilePathPrefix(); + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + input + .get("input") + .apply( + MapElements.into(TypeDescriptors.strings()) + .via(row -> Preconditions.checkStateNotNull(row.getString("line")))) + .apply(TextIO.write().to(filePathPrefix)); - return PCollectionRowTuple.empty(input.getPipeline()); - } - }; + return PCollectionRowTuple.empty(input.getPipeline()); + } } @DefaultSchema(AutoValueSchema.class) @AutoValue - protected abstract static class Configuration { + public abstract static class Configuration { public static Builder builder() { return new AutoValue_WriteWordsProvider_Configuration.Builder(); } diff --git a/gradle.properties b/gradle.properties index 3923dc204272..dea5966f825d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -30,8 +30,8 @@ signing.gnupg.useLegacyGpg=true # buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy. # To build a custom Beam version make sure you change it in both places, see # https://github.com/apache/beam/issues/21302. -version=2.62.0-SNAPSHOT -sdk_version=2.62.0.dev +version=2.63.0-SNAPSHOT +sdk_version=2.63.0.dev javaVersion=1.8 diff --git a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java index a9ae68142778..7ea8dece31bb 100644 --- a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java +++ b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java @@ -79,11 +79,20 @@ * *

Example trigger command for specific test running on Dataflow runner: * + *

Maven + * *

  * mvn test -pl it/google-cloud-platform -am -Dtest="BigQueryIOLT#testAvroFileLoadsWriteThenRead" \
  * -Dconfiguration=medium -Dproject=[gcpProject] -DartifactBucket=[temp bucket] -DfailIfNoTests=false
  * 
* + *

Gradle + * + *

+ * ./gradlew :it:google-cloud-platform:BigQueryPerformanceTest --tests='BigQueryIOLT.testAvroFileLoadsWriteThenRead' \
+ * -Dconfiguration=medium -Dproject=[gcpProject] -DartifactBucket=[temp bucket] -DfailIfNoTests=false
+ * 
+ * *

Example trigger command for specific test and custom data configuration: * *

mvn test -pl it/google-cloud-platform -am \
@@ -172,11 +181,11 @@ public static void tearDownClass() {
                   Configuration.class), // 1 MB
               "medium",
               Configuration.fromJsonString(
-                  "{\"numRecords\":10000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":20,\"runner\":\"DataflowRunner\"}",
+                  "{\"numRecords\":10000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":20,\"runner\":\"DataflowRunner\",\"workerMachineType\":\"e2-standard-2\",\"experiments\":\"disable_runner_v2\",\"numWorkers\":\"1\",\"maxNumWorkers\":\"1\"}",
                   Configuration.class), // 10 GB
               "large",
               Configuration.fromJsonString(
-                  "{\"numRecords\":100000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":80,\"runner\":\"DataflowRunner\"}",
+                  "{\"numRecords\":100000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":80,\"runner\":\"DataflowRunner\",\"workerMachineType\":\"e2-standard-2\",\"experiments\":\"disable_runner_v2\",\"numWorkers\":\"1\",\"maxNumWorkers\":\"1\",\"numStorageWriteApiStreams\":4,\"storageWriteApiTriggeringFrequencySec\":20}",
                   Configuration.class) // 100 GB
               );
     } catch (IOException e) {
@@ -230,16 +239,19 @@ public void testWriteAndRead() throws IOException {
         writeIO =
             BigQueryIO.write()
                 .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)
+                .withNumStorageWriteApiStreams(
+                    configuration.numStorageWriteApiStreams) // control the number of streams
                 .withAvroFormatFunction(
                     new AvroFormatFn(
                         configuration.numColumns,
                         !("STORAGE_WRITE_API".equalsIgnoreCase(configuration.writeMethod))));
-
         break;
       case JSON:
         writeIO =
             BigQueryIO.write()
                 .withSuccessfulInsertsPropagation(false)
+                .withNumStorageWriteApiStreams(
+                    configuration.numStorageWriteApiStreams) // control the number of streams
                 .withFormatFunction(new JsonFormatFn(configuration.numColumns));
         break;
     }
@@ -268,6 +280,10 @@ private void testWrite(BigQueryIO.Write writeIO) throws IOException {
             .setSdk(PipelineLauncher.Sdk.JAVA)
             .setPipeline(writePipeline)
             .addParameter("runner", configuration.runner)
+            .addParameter("workerMachineType", configuration.workerMachineType)
+            .addParameter("experiments", configuration.experiments)
+            .addParameter("numWorkers", configuration.numWorkers)
+            .addParameter("maxNumWorkers", configuration.maxNumWorkers)
             .build();
 
     PipelineLauncher.LaunchInfo launchInfo = pipelineLauncher.launch(project, region, options);
@@ -304,6 +320,10 @@ private void testRead() throws IOException {
             .setSdk(PipelineLauncher.Sdk.JAVA)
             .setPipeline(readPipeline)
             .addParameter("runner", configuration.runner)
+            .addParameter("workerMachineType", configuration.workerMachineType)
+            .addParameter("experiments", configuration.experiments)
+            .addParameter("numWorkers", configuration.numWorkers)
+            .addParameter("maxNumWorkers", configuration.maxNumWorkers)
             .build();
 
     PipelineLauncher.LaunchInfo launchInfo = pipelineLauncher.launch(project, region, options);
@@ -445,12 +465,36 @@ static class Configuration extends SyntheticSourceOptions {
     /** Runner specified to run the pipeline. */
     @JsonProperty public String runner = "DirectRunner";
 
+    /** Worker machine type specified to run the pipeline with Dataflow Runner. */
+    @JsonProperty public String workerMachineType = "";
+
+    /** Experiments specified to run the pipeline. */
+    @JsonProperty public String experiments = "";
+
+    /** Number of workers to start the pipeline. Must be a positive value. */
+    @JsonProperty public String numWorkers = "1";
+
+    /** Maximum number of workers for the pipeline. Must be a positive value. */
+    @JsonProperty public String maxNumWorkers = "1";
+
     /** BigQuery read method: DEFAULT/DIRECT_READ/EXPORT. */
     @JsonProperty public String readMethod = "DEFAULT";
 
     /** BigQuery write method: DEFAULT/FILE_LOADS/STREAMING_INSERTS/STORAGE_WRITE_API. */
     @JsonProperty public String writeMethod = "DEFAULT";
 
+    /**
+     * BigQuery number of streams for write method STORAGE_WRITE_API. 0 lets the runner determine
+     * the number of streams. Remark: the max limit for open connections per hour is 10K streams.
+     */
+    @JsonProperty public int numStorageWriteApiStreams = 0;
+
+    /**
+     * BigQuery triggering frequency in seconds, in combination with the number of streams, for
+     * write method STORAGE_WRITE_API.
+     */
+    @JsonProperty public int storageWriteApiTriggeringFrequencySec = 20;
+
     /** BigQuery write format: AVRO/JSON. */
     @JsonProperty public String writeFormat = "AVRO";
   }
diff --git a/local-env-setup.sh b/local-env-setup.sh
index ba30813b2bcc..b75cf14f22c4 100755
--- a/local-env-setup.sh
+++ b/local-env-setup.sh
@@ -24,7 +24,7 @@ darwin_install_pip3_packages() {
 
 install_go_packages(){
         echo "Installing goavro"
-        go get github.com/linkedin/goavro/v2
+        go mod init beam-runtime && go get github.com/linkedin/goavro/v2
         # As we are using bash, we are assuming .bashrc exists.
         grep -qxF "export GOPATH=${PWD}/sdks/go/examples/.gogradle/project_gopath" ~/.bashrc
         gopathExists=$?
diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto
index 4ec189e4637f..33bb5ae729f8 100644
--- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto
+++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto
@@ -198,6 +198,17 @@ message MonitoringInfoSpecs {
       }]
     }];
 
+    // Represents a bounded trie of strings seen across bundles.
+    USER_BOUNDED_TRIE = 22 [(monitoring_info_spec) = {
+      urn: "beam:metric:user:bounded_trie:v1",
+      type: "beam:metrics:bounded_trie:v1",
+      required_labels: ["PTRANSFORM", "NAMESPACE", "NAME"],
+      annotations: [{
+        key: "description",
+        value: "URN utilized to report user metric."
+      }]
+    }];
+
     // General monitored state information which contains structured information
     // which does not fit into a typical metric format. See MonitoringTableData
     // for more details.
@@ -576,6 +587,12 @@ message MonitoringInfoTypeUrns {
     SET_STRING_TYPE = 11 [(org.apache.beam.model.pipeline.v1.beam_urn) =
       "beam:metrics:set_string:v1"];
 
+    // Represents a bounded trie of strings.
+    //
+    // Encoding: BoundedTrie proto
+    BOUNDED_TRIE_TYPE = 12 [(org.apache.beam.model.pipeline.v1.beam_urn) =
+      "beam:metrics:bounded_trie:v1"];
+
     // General monitored state information which contains structured information
     // which does not fit into a typical metric format. See MonitoringTableData
     // for more details.
@@ -588,6 +605,30 @@ message MonitoringInfoTypeUrns {
   }
 }
 
+
+// A single node in a BoundedTrie.
+message BoundedTrieNode {
+  // Whether this node has been truncated.
+  // A truncated leaf represents possibly many children with the same prefix.
+  bool truncated = 1;
+
+  // Children of this node.  Must be empty if truncated is true.
+  map<string, BoundedTrieNode> children = 2;
+}
+
+// The message type used for encoding metrics of type bounded trie.
+message BoundedTrie {
+  // The maximum number of elements to store before truncation.
+  int32 bound = 1;
+
+  // A compact representation of all the elements in this trie.
+  BoundedTrieNode root = 2;
+
+  // A more efficient representation for metrics consisting of a single value.
+  repeated string singleton = 3;
+}
+
+
 // General monitored state information which contains structured information
 // which does not fit into a typical metric format.
 //
diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index d13e1c5faf6e..be39d4e0b012 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -422,3 +422,8 @@ createPipelineOptionsTableTask('Python')
 // Update the pipeline options documentation before running the tests
 test.dependsOn(generatePipelineOptionsTableJava)
 test.dependsOn(generatePipelineOptionsTablePython)
+
+// delegate spotlessApply to :runners:flink:spotlessApply
+tasks.named("spotlessApply") {
+  dependsOn ":runners:flink:spotlessApply"
+}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index ce99958c57fd..9ca6e95ed95a 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -103,6 +103,7 @@
 import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
+import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverride;
@@ -1252,6 +1253,8 @@ public DataflowPipelineJob run(Pipeline pipeline) {
         experiments.add("use_portable_job_submission");
       }
       options.setExperiments(ImmutableList.copyOf(experiments));
+      // Ensure that logging via the FnApi is enabled
+      options.as(SdkHarnessOptions.class).setEnableLogViaFnApi(true);
     }
 
     logWarningIfPCollectionViewHasNonDeterministicKeyCoder(pipeline);
diff --git a/runners/google-cloud-dataflow-java/worker/build.gradle b/runners/google-cloud-dataflow-java/worker/build.gradle
index 92beccd067e2..b7e6e981effe 100644
--- a/runners/google-cloud-dataflow-java/worker/build.gradle
+++ b/runners/google-cloud-dataflow-java/worker/build.gradle
@@ -54,7 +54,6 @@ def sdk_provided_project_dependencies = [
         ":runners:google-cloud-dataflow-java",
         ":sdks:java:extensions:avro",
         ":sdks:java:extensions:google-cloud-platform-core",
-        ":sdks:java:io:kafka", // For metric propagation into worker
         ":sdks:java:io:google-cloud-platform",
 ]
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java
index 77f867793ae2..91baefa0be4c 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java
@@ -32,7 +32,6 @@
 import java.util.Map.Entry;
 import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics;
-import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics;
 import org.apache.beam.sdk.metrics.LabeledMetricNameUtils;
 import org.apache.beam.sdk.metrics.MetricName;
 import org.apache.beam.sdk.util.HistogramData;
@@ -43,6 +42,9 @@
  * converter.
  */
 public class MetricsToPerStepNamespaceMetricsConverter {
+  // Avoids introducing a mandatory kafka-io dependency into the Dataflow worker.
+  // Keep in sync with org.apache.beam.sdk.io.kafka.KafkaSinkMetrics.METRICS_NAMESPACE.
+  public static String KAFKA_SINK_METRICS_NAMESPACE = "KafkaSink";
 
   private static Optional getParsedMetricName(
       MetricName metricName,
@@ -70,7 +72,7 @@ private static Optional convertCounterToMetricValue(
 
     if (value == 0
         || (!metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE)
-            && !metricName.getNamespace().equals(KafkaSinkMetrics.METRICS_NAMESPACE))) {
+            && !metricName.getNamespace().equals(KAFKA_SINK_METRICS_NAMESPACE))) {
       return Optional.empty();
     }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java
index 9f4b413841c5..acfd8a291108 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java
@@ -22,7 +22,9 @@
 /** Indicates that an output element was too large. */
 public class OutputTooLargeException extends RuntimeException {
   public OutputTooLargeException(String reason) {
-    super(reason);
+    super(
+        reason
+            + " See https://cloud.google.com/dataflow/docs/guides/common-errors#key-commit-too-large-exception.");
   }
 
   /** Returns whether an exception was caused by a {@link OutputTooLargeException}. */
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 088a28e9b2db..0112ab4af80a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -110,7 +110,6 @@
 import org.apache.beam.sdk.fn.JvmInitializers;
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics;
-import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.construction.CoderTranslation;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
@@ -835,10 +834,6 @@ public static void main(String[] args) throws Exception {
       enableBigQueryMetrics();
     }
 
-    if (DataflowRunner.hasExperiment(options, "enable_kafka_metrics")) {
-      KafkaSinkMetrics.setSupportKafkaMetrics(true);
-    }
-
     JvmInitializers.runBeforeProcessing(options);
     worker.startStatusPages();
     worker.start();
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java
index 78d0c6b4550a..f83c68ab3c90 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java
@@ -183,7 +183,9 @@ public long add(WindowedValue data) throws IOException {
               "Trying to output too large key with size "
                   + key.size()
                   + ". Limit is "
-                  + context.getMaxOutputKeyBytes());
+                  + context.getMaxOutputKeyBytes()
+                  + ". See https://cloud.google.com/dataflow/docs/guides/common-errors#key-commit-too-large-exception."
+                  + " Running with --experiments=throw_exceptions_on_large_output will instead throw an OutputTooLargeException which may be caught in user code.");
         }
       }
       if (value.size() > context.getMaxOutputValueBytes()) {
@@ -194,7 +196,9 @@ public long add(WindowedValue data) throws IOException {
               "Trying to output too large value with size "
                   + value.size()
                   + ". Limit is "
-                  + context.getMaxOutputValueBytes());
+                  + context.getMaxOutputValueBytes()
+                  + ". See https://cloud.google.com/dataflow/docs/guides/common-errors#key-commit-too-large-exception."
+                  + " Running with --experiments=throw_exceptions_on_large_output will instead throw an OutputTooLargeException which may be caught in user code.");
         }
       }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java
index 090d9981309e..76228b9092b3 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/KeyCommitTooLargeException.java
@@ -40,7 +40,8 @@ public static KeyCommitTooLargeException causedBy(
     message.append(
         ". This may be caused by grouping a very "
             + "large amount of data in a single window without using Combine,"
-            + " or by producing a large amount of data from a single input element.");
+            + " or by producing a large amount of data from a single input element."
+            + " See https://cloud.google.com/dataflow/docs/guides/common-errors#key-commit-too-large-exception.");
     return new KeyCommitTooLargeException(message.toString());
   }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java
index 525464ef2e1f..d9fe95f3421b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.runners.dataflow.worker.streaming;
 
+import static org.apache.beam.runners.dataflow.worker.MetricsToPerStepNamespaceMetricsConverter.KAFKA_SINK_METRICS_NAMESPACE;
 import static org.apache.beam.sdk.metrics.Metrics.THROTTLE_TIME_COUNTER_NAME;
 
 import com.google.api.services.dataflow.model.CounterStructuredName;
@@ -35,7 +36,6 @@
 import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor;
 import org.apache.beam.runners.dataflow.worker.counters.NameContext;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics;
-import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 
 /** Contains a few of the stage specific fields. E.g. metrics container registry, counters etc. */
@@ -120,8 +120,7 @@ private void translateKnownPerWorkerCounters(List metri
     for (PerStepNamespaceMetrics perStepnamespaceMetrics : metrics) {
       if (!BigQuerySinkMetrics.METRICS_NAMESPACE.equals(
               perStepnamespaceMetrics.getMetricsNamespace())
-          && !KafkaSinkMetrics.METRICS_NAMESPACE.equals(
-              perStepnamespaceMetrics.getMetricsNamespace())) {
+          && !KAFKA_SINK_METRICS_NAMESPACE.equals(perStepnamespaceMetrics.getMetricsNamespace())) {
         continue;
       }
       for (MetricValue metric : perStepnamespaceMetrics.getMetricValues()) {
diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle
index 82eb62b9e207..a48973f65674 100644
--- a/runners/prism/java/build.gradle
+++ b/runners/prism/java/build.gradle
@@ -106,6 +106,12 @@ def sickbayTests = [
     'org.apache.beam.sdk.testing.TestStreamTest.testMultipleStreams',
     'org.apache.beam.sdk.testing.TestStreamTest.testProcessingTimeTrigger',
 
+    // GroupIntoBatchesTest tests that fail:
+    // Teststream has bad KV encodings due to using an outer context.
+    'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testInStreamingMode',
+    // ShardedKey not yet implemented.
+    'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow',
+
     // Coding error somehow: short write: reached end of stream after reading 5 bytes; 98 bytes expected
     'org.apache.beam.sdk.testing.TestStreamTest.testMultiStage',
 
@@ -181,14 +187,6 @@ def sickbayTests = [
     // Missing output due to processing time timer skew.
     'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testProcessElementSkew',
 
-    // TestStream + BundleFinalization.
-    // Tests seem to assume individual element bundles from test stream, but prism will aggregate them, preventing
-    // a subsequent firing. Tests ultimately hang until timeout.
-    // Either a test problem, or a misunderstanding of how test stream must work problem in prism.
-    // Biased to test problem, due to how they are constructed.
-    'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalization',
-    'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalizationWithSideInputs',
-
     // Filtered by PortableRunner tests.
     // Teardown not called in exceptions
     // https://github.com/apache/beam/issues/20372
@@ -227,15 +225,13 @@ def createPrismValidatesRunnerTask = { name, environmentType ->
       excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService'
       excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'
 
-      // Not yet implemented in Prism
-      // https://github.com/apache/beam/issues/32211
-      excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration'
-      // https://github.com/apache/beam/issues/32929
-      excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
+      // Not supported in Portable Java SDK yet.
+      // https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState
+      excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
 
-       // Not supported in Portable Java SDK yet.
-       // https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState
-       excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
+      // Processing time with TestStream is unreliable without being able to control
+      // SDK side time portably. Ignore these tests.
+      excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
     }
     filter {
       for (String test : sickbayTests) {
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index f4e6bf740189..297facd4bc0d 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -345,7 +345,7 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test)
     excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'
 
     // State and Timers
-    excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo'
+    excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'
     excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
     excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap'
     excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer'
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java
index 619d2d16173d..44f8d6df683b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java
@@ -21,7 +21,7 @@
 import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import org.apache.beam.runners.spark.io.MicrobatchSource;
-import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet.StateAndTimers;
+import org.apache.beam.runners.spark.stateful.StateAndTimers;
 import org.apache.beam.runners.spark.translation.ValueAndCoderKryoSerializer;
 import org.apache.beam.runners.spark.translation.ValueAndCoderLazySerializable;
 import org.apache.beam.runners.spark.util.ByteArray;
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index c24841c7dd31..b18b31a67463 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -17,6 +17,9 @@
  */
 package org.apache.beam.runners.spark.stateful;
 
+import static org.apache.beam.runners.spark.translation.TranslationUtils.checkpointIfNeeded;
+import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration;
+
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -35,7 +38,6 @@
 import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
 import org.apache.beam.runners.core.triggers.TriggerStateMachines;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.translation.ReifyTimestampsAndWindowsFunction;
 import org.apache.beam.runners.spark.translation.TranslationUtils;
@@ -60,10 +62,8 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.api.java.function.FlatMapFunction;
-import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.dstream.DStream;
@@ -100,27 +100,6 @@ public class SparkGroupAlsoByWindowViaWindowSet implements Serializable {
   private static final Logger LOG =
       LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class);
 
-  /** State and Timers wrapper. */
-  public static class StateAndTimers implements Serializable {
-    // Serializable state for internals (namespace to state tag to coded value).
-    private final Table state;
-    private final Collection serTimers;
-
-    private StateAndTimers(
-        final Table state, final Collection timers) {
-      this.state = state;
-      this.serTimers = timers;
-    }
-
-    Table getState() {
-      return state;
-    }
-
-    Collection getTimers() {
-      return serTimers;
-    }
-  }
-
   private static class OutputWindowedValueHolder
       implements OutputWindowedValue>> {
     private final List>>> windowedValues = new ArrayList<>();
@@ -348,7 +327,7 @@ private Collection filterTimersEligibleForProcessing(
 
             // empty outputs are filtered later using DStream filtering
             final StateAndTimers updated =
-                new StateAndTimers(
+                StateAndTimers.of(
                     stateInternals.getState(),
                     SparkTimerInternals.serializeTimers(
                         timerInternals.getTimers(), timerDataCoder));
@@ -466,21 +445,6 @@ private static  TimerInternals.TimerDataCoderV2 timerDa
     return TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder());
   }
 
-  private static void checkpointIfNeeded(
-      final DStream>>> firedStream,
-      final SerializablePipelineOptions options) {
-
-    final Long checkpointDurationMillis = getBatchDuration(options);
-
-    if (checkpointDurationMillis > 0) {
-      firedStream.checkpoint(new Duration(checkpointDurationMillis));
-    }
-  }
-
-  private static Long getBatchDuration(final SerializablePipelineOptions options) {
-    return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
-  }
-
   private static  JavaDStream>>> stripStateValues(
       final DStream>>> firedStream,
       final Coder keyCoder,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
index 5890662307fb..77ae042d81fa 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
@@ -63,7 +63,7 @@
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
 })
-class SparkStateInternals implements StateInternals {
+public class SparkStateInternals implements StateInternals {
 
   private final K key;
   // Serializable state for internals (namespace to state tag to coded value).
@@ -79,11 +79,11 @@ private SparkStateInternals(K key, Table stateTable) {
     this.stateTable = stateTable;
   }
 
-  static  SparkStateInternals forKey(K key) {
+  public static  SparkStateInternals forKey(K key) {
     return new SparkStateInternals<>(key);
   }
 
-  static  SparkStateInternals forKeyAndState(
+  public static  SparkStateInternals forKeyAndState(
       K key, Table stateTable) {
     return new SparkStateInternals<>(key, stateTable);
   }
@@ -412,7 +412,7 @@ public void put(MapKeyT key, MapValueT value) {
     @Override
     public ReadableState computeIfAbsent(
         MapKeyT key, Function mappingFunction) {
-      Map sparkMapState = readValue();
+      Map sparkMapState = readAsMap();
       MapValueT current = sparkMapState.get(key);
       if (current == null) {
         put(key, mappingFunction.apply(key));
@@ -420,9 +420,17 @@ public ReadableState computeIfAbsent(
       return ReadableStates.immediate(current);
     }
 
+    private Map readAsMap() {
+      Map mapState = readValue();
+      if (mapState == null) {
+        mapState = new HashMap<>();
+      }
+      return mapState;
+    }
+
     @Override
     public void remove(MapKeyT key) {
-      Map sparkMapState = readValue();
+      Map sparkMapState = readAsMap();
       sparkMapState.remove(key);
       writeValue(sparkMapState);
     }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
index de9820e1255c..8b647c42dd7e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
@@ -107,7 +107,7 @@ public Collection getTimers() {
     return timers;
   }
 
-  void addTimers(Iterator timers) {
+  public void addTimers(Iterator timers) {
     while (timers.hasNext()) {
       TimerData timer = timers.next();
       this.timers.add(timer);
@@ -163,7 +163,8 @@ public void setTimer(
       Instant target,
       Instant outputTimestamp,
       TimeDomain timeDomain) {
-    throw new UnsupportedOperationException("Setting a timer by ID not yet supported.");
+    this.setTimer(
+        TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain));
   }
 
   @Override
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoderProviderRegistrar.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateAndTimers.java
similarity index 51%
rename from sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoderProviderRegistrar.java
rename to runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateAndTimers.java
index 5a187e734d66..83eaddde5532 100644
--- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoderProviderRegistrar.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateAndTimers.java
@@ -15,23 +15,31 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.sdk.io.aws.dynamodb;
+package org.apache.beam.runners.spark.stateful;
 
-import com.amazonaws.services.dynamodbv2.model.AttributeValue;
-import com.google.auto.service.AutoService;
-import java.util.List;
-import org.apache.beam.sdk.coders.CoderProvider;
-import org.apache.beam.sdk.coders.CoderProviderRegistrar;
-import org.apache.beam.sdk.coders.CoderProviders;
-import org.apache.beam.sdk.values.TypeDescriptor;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+import java.util.Collection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
 
-/** A {@link CoderProviderRegistrar} for standard types used with {@link DynamoDBIO}. */
-@AutoService(CoderProviderRegistrar.class)
-public class AttributeValueCoderProviderRegistrar implements CoderProviderRegistrar {
-  @Override
-  public List getCoderProviders() {
-    return ImmutableList.of(
-        CoderProviders.forCoder(TypeDescriptor.of(AttributeValue.class), AttributeValueCoder.of()));
+/** State and Timers wrapper. */
+@AutoValue
+public abstract class StateAndTimers implements Serializable {
+  public abstract Table getState();
+
+  public abstract Collection getTimers();
+
+  public static StateAndTimers of(
+      final Table state, final Collection timers) {
+    return new AutoValue_StateAndTimers.Builder().setState(state).setTimers(timers).build();
+  }
+
+  @AutoValue.Builder
+  abstract static class Builder {
+    abstract Builder setState(Table state);
+
+    abstract Builder setTimers(Collection timers);
+
+    abstract StateAndTimers build();
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
index 8bbcb1f2941a..34836cd6e7ae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
@@ -31,12 +31,12 @@
 import org.joda.time.Instant;
 
 /** DoFnRunner decorator which registers {@link MetricsContainerImpl}. */
-class DoFnRunnerWithMetrics implements DoFnRunner {
+public class DoFnRunnerWithMetrics implements DoFnRunner {
   private final DoFnRunner delegate;
   private final String stepName;
   private final MetricsContainerStepMapAccumulator metricsAccum;
 
-  DoFnRunnerWithMetrics(
+  public DoFnRunnerWithMetrics(
       String stepName,
       DoFnRunner delegate,
       MetricsContainerStepMapAccumulator metricsAccum) {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
index 62c5e2579427..1d8901ed5ffc 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -17,6 +17,9 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.util.ByteArray;
 import org.apache.beam.sdk.coders.Coder;
@@ -27,6 +30,7 @@
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
@@ -49,18 +53,36 @@ public static  JavaRDD>>> groupByKeyOnly(
       @Nullable Partitioner partitioner) {
     // we use coders to convert objects in the PCollection to byte arrays, so they
     // can be transferred over the network for the shuffle.
-    JavaPairRDD pairRDD =
-        rdd.map(new ReifyTimestampsAndWindowsFunction<>())
-            .mapToPair(TranslationUtils.toPairFunction())
-            .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder));
-
-    // If no partitioner is passed, the default group by key operation is called
-    JavaPairRDD> groupedRDD =
-        (partitioner != null) ? pairRDD.groupByKey(partitioner) : pairRDD.groupByKey();
-
-    return groupedRDD
-        .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder))
-        .map(new TranslationUtils.FromPairFunction<>());
+    final JavaPairRDD pairRDD =
+        rdd.mapPartitionsToPair(
+            (Iterator>> iter) ->
+                Iterators.transform(
+                    iter,
+                    (WindowedValue> wv) -> {
+                      final K key = wv.getValue().getKey();
+                      final WindowedValue windowedValue = wv.withValue(wv.getValue().getValue());
+                      final ByteArray keyBytes =
+                          new ByteArray(CoderHelpers.toByteArray(key, keyCoder));
+                      final byte[] windowedValueBytes =
+                          CoderHelpers.toByteArray(windowedValue, wvCoder);
+                      return Tuple2.apply(keyBytes, windowedValueBytes);
+                    }));
+
+    final JavaPairRDD> combined =
+        GroupNonMergingWindowsFunctions.combineByKey(pairRDD, partitioner).cache();
+
+    return combined.mapPartitions(
+        (Iterator>> iter) ->
+            Iterators.transform(
+                iter,
+                (Tuple2> tuple) -> {
+                  final K key = CoderHelpers.fromByteArray(tuple._1().getValue(), keyCoder);
+                  final List> windowedValues =
+                      tuple._2().stream()
+                          .map(bytes -> CoderHelpers.fromByteArray(bytes, wvCoder))
+                          .collect(Collectors.toList());
+                  return KV.of(key, windowedValues);
+                }));
   }
 
   /**
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java
index 2461d5cc8d66..14630fbb0a1f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java
@@ -17,7 +17,9 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Objects;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.util.ByteArray;
@@ -41,6 +43,9 @@
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -259,9 +264,12 @@ private WindowedValue> decodeItem(Tuple2 item) {
   }
 
   /**
-   * Group all values with a given key for that composite key with Spark's groupByKey, dropping the
-   * Window (which must be GlobalWindow) and returning the grouped result in the appropriate global
-   * window.
+   * Groups values with a given key using Spark's combineByKey operation in the Global Window
+   * context. The window information (which must be GlobalWindow) is dropped during processing, and
+   * the grouped results are returned in the appropriate global window with the maximum timestamp.
+   *
+   * 

This implementation uses {@link JavaPairRDD#combineByKey} for better performance compared to + * {@link JavaPairRDD#groupByKey}, as it allows for local aggregation before shuffle operations. */ static JavaRDD>>> groupByKeyInGlobalWindow( @@ -269,24 +277,70 @@ JavaRDD>>> groupByKeyInGlobalWindow( Coder keyCoder, Coder valueCoder, Partitioner partitioner) { - JavaPairRDD rawKeyValues = - rdd.mapToPair( - wv -> - new Tuple2<>( - new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)), - CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder))); - - JavaPairRDD> grouped = - (partitioner == null) ? rawKeyValues.groupByKey() : rawKeyValues.groupByKey(partitioner); - return grouped.map( - kvs -> - WindowedValue.timestampedValueInGlobalWindow( - KV.of( - CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder), - Iterables.transform( - kvs._2, - encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))), - GlobalWindow.INSTANCE.maxTimestamp(), - PaneInfo.ON_TIME_AND_ONLY_FIRING)); + final JavaPairRDD rawKeyValues = + rdd.mapPartitionsToPair( + (Iterator>> iter) -> + Iterators.transform( + iter, + (WindowedValue> wv) -> { + final ByteArray keyBytes = + new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)); + final byte[] valueBytes = + CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder); + return Tuple2.apply(keyBytes, valueBytes); + })); + + JavaPairRDD> combined = combineByKey(rawKeyValues, partitioner).cache(); + + return combined.mapPartitions( + (Iterator>> iter) -> + Iterators.transform( + iter, + kvs -> + WindowedValue.timestampedValueInGlobalWindow( + KV.of( + CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder), + Iterables.transform( + kvs._2(), + encodedValue -> + CoderHelpers.fromByteArray(encodedValue, valueCoder))), + GlobalWindow.INSTANCE.maxTimestamp(), + PaneInfo.ON_TIME_AND_ONLY_FIRING))); + } + + /** + * Combines values by key using Spark's {@link JavaPairRDD#combineByKey} operation. 
+ * + * @param rawKeyValues Input RDD of key-value pairs + * @param partitioner Optional custom partitioner for data distribution + * @return RDD with values combined into Lists per key + */ + static JavaPairRDD> combineByKey( + JavaPairRDD rawKeyValues, @Nullable Partitioner partitioner) { + + final Function> createCombiner = + value -> { + List list = new ArrayList<>(); + list.add(value); + return list; + }; + + final Function2, byte[], List> mergeValues = + (list, value) -> { + list.add(value); + return list; + }; + + final Function2, List, List> mergeCombiners = + (list1, list2) -> { + list1.addAll(list2); + return list1; + }; + + if (partitioner == null) { + return rawKeyValues.combineByKey(createCombiner, mergeValues, mergeCombiners); + } + + return rawKeyValues.combineByKey(createCombiner, mergeValues, mergeCombiners, partitioner); } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java index 0af480a2ff02..4b4d23b0c47c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java @@ -47,7 +47,7 @@ * Processes Spark's input data iterators using Beam's {@link * org.apache.beam.runners.core.DoFnRunner}. 
*/ -interface SparkInputDataProcessor { +public interface SparkInputDataProcessor { /** * @return {@link OutputManager} to be used by {@link org.apache.beam.runners.core.DoFnRunner} for diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java index 5487bb1be73c..bbcd74dc408b 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java @@ -23,14 +23,14 @@ import org.apache.beam.sdk.transforms.DoFn; /** Holds current processing context for {@link SparkInputDataProcessor}. */ -class SparkProcessContext { +public class SparkProcessContext { private final String stepName; private final DoFn doFn; private final DoFnRunner doFnRunner; private final Iterator timerDataIterator; private final K key; - SparkProcessContext( + public SparkProcessContext( String stepName, DoFn doFn, DoFnRunner doFnRunner, diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java index 23af6f71b938..f2455e64b956 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -26,6 +26,8 @@ import org.apache.beam.runners.core.InMemoryStateInternals; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.SparkRunner; import org.apache.beam.runners.spark.coders.CoderHelpers; import 
org.apache.beam.runners.spark.util.ByteArray; @@ -54,8 +56,10 @@ import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.dstream.DStream; import scala.Tuple2; /** A set of utilities to help translating Beam transformations into Spark transformations. */ @@ -258,6 +262,52 @@ public Boolean call(Tuple2, WindowedValue> input) { } } + /** + * Retrieves the batch duration in milliseconds from Spark pipeline options. + * + * @param options The serializable pipeline options containing Spark-specific settings + * @return The checkpoint duration in milliseconds as specified in SparkPipelineOptions + */ + public static Long getBatchDuration(final SerializablePipelineOptions options) { + return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis(); + } + + /** + * Reject timers {@link DoFn}. + * + * @param doFn the {@link DoFn} to possibly reject. + */ + public static void rejectTimers(DoFn doFn) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + if (signature.timerDeclarations().size() > 0 + || signature.timerFamilyDeclarations().size() > 0) { + throw new UnsupportedOperationException( + String.format( + "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.", + DoFn.TimerId.class.getSimpleName(), + doFn.getClass().getName(), + DoFn.class.getSimpleName(), + SparkRunner.class.getSimpleName())); + } + } + + /** + * Checkpoints the given DStream if checkpointing is enabled in the pipeline options. 
+ * + * @param dStream The DStream to be checkpointed + * @param options The SerializablePipelineOptions containing configuration settings including + * batch duration + */ + public static void checkpointIfNeeded( + final DStream dStream, final SerializablePipelineOptions options) { + + final Long checkpointDurationMillis = getBatchDuration(options); + + if (checkpointDurationMillis > 0) { + dStream.checkpoint(new Duration(checkpointDurationMillis)); + } + } + /** * Reject state and timers {@link DoFn}. * diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java new file mode 100644 index 000000000000..82557c3b972b --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.translation.streaming; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StatefulDoFnRunner; +import org.apache.beam.runners.core.StepContext; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator; +import org.apache.beam.runners.spark.stateful.SparkStateInternals; +import org.apache.beam.runners.spark.stateful.SparkTimerInternals; +import org.apache.beam.runners.spark.stateful.StateAndTimers; +import org.apache.beam.runners.spark.translation.DoFnRunnerWithMetrics; +import org.apache.beam.runners.spark.translation.SparkInputDataProcessor; +import org.apache.beam.runners.spark.translation.SparkProcessContext; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.beam.runners.spark.util.CachedSideInputReader; +import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; +import org.apache.beam.runners.spark.util.SideInputBroadcast; +import org.apache.beam.runners.spark.util.SparkSideInputReader; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import 
org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.spark.streaming.State; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; +import scala.Tuple2; +import scala.runtime.AbstractFunction3; + +/** + * A function to handle stateful processing in Apache Beam's SparkRunner. This class processes + * stateful DoFn operations by managing state updates in a Spark streaming context. + * + *

Current Implementation Status: + * + *

    + *
  • State: Fully implemented and supported through {@link SparkStateInternals} + *
  • Timers: Not supported. While {@link SparkTimerInternals} is present in the code, timer + * functionality is not yet fully implemented and operational + *
+ * + * @param The type of the key in the input KV pairs + * @param The type of the value in the input KV pairs + * @param The input type, must be a KV of KeyT and ValueT + * @param The output type produced by the DoFn + */ +@SuppressWarnings({"rawtypes", "unchecked"}) +public class ParDoStateUpdateFn, OutputT> + extends AbstractFunction3< + /*Serialized KeyT*/ ByteArray, + Option*/ byte[]>, + /*State*/ State, + List, /*Serialized WindowedValue*/ byte[]>>> + implements Serializable { + + @SuppressWarnings("unused") + private static final Logger LOG = LoggerFactory.getLogger(ParDoStateUpdateFn.class); + + private final MetricsContainerStepMapAccumulator metricsAccum; + private final String stepName; + private final DoFn doFn; + private final Coder keyCoder; + private final WindowedValue.FullWindowedValueCoder wvCoder; + private transient boolean wasSetupCalled; + private final SerializablePipelineOptions options; + private final TupleTag mainOutputTag; + private final List> additionalOutputTags; + private final Coder inputCoder; + private final Map, Coder> outputCoders; + private final Map, KV, SideInputBroadcast>> sideInputs; + private final WindowingStrategy windowingStrategy; + private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; + // for timer + private final Map watermarks; + private final List sourceIds; + private final TimerInternals.TimerDataCoderV2 timerDataCoder; + + public ParDoStateUpdateFn( + MetricsContainerStepMapAccumulator metricsAccum, + String stepName, + DoFn doFn, + Coder keyCoder, + WindowedValue.FullWindowedValueCoder wvCoder, + SerializablePipelineOptions options, + TupleTag mainOutputTag, + List> additionalOutputTags, + Coder inputCoder, + Map, Coder> outputCoders, + Map, KV, SideInputBroadcast>> sideInputs, + WindowingStrategy windowingStrategy, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, + Map watermarks, + List sourceIds) { + this.metricsAccum = metricsAccum; + 
this.stepName = stepName; + this.doFn = SerializableUtils.clone(doFn); + this.options = options; + this.mainOutputTag = mainOutputTag; + this.additionalOutputTags = additionalOutputTags; + this.keyCoder = keyCoder; + this.inputCoder = inputCoder; + this.outputCoders = outputCoders; + this.wvCoder = wvCoder; + this.sideInputs = sideInputs; + this.windowingStrategy = windowingStrategy; + this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; + this.watermarks = watermarks; + this.sourceIds = sourceIds; + this.timerDataCoder = + TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder()); + } + + @Override + public List, /*Serialized WindowedValue*/ byte[]>> + apply(ByteArray serializedKey, Option serializedValue, State state) { + if (serializedValue.isEmpty()) { + return Lists.newArrayList(); + } + + SparkStateInternals stateInternals; + final SparkTimerInternals timerInternals = + SparkTimerInternals.forStreamFromSources(sourceIds, watermarks); + final KeyT key = CoderHelpers.fromByteArray(serializedKey.getValue(), this.keyCoder); + + if (state.exists()) { + final StateAndTimers stateAndTimers = state.get(); + stateInternals = SparkStateInternals.forKeyAndState(key, stateAndTimers.getState()); + timerInternals.addTimers( + SparkTimerInternals.deserializeTimers(stateAndTimers.getTimers(), timerDataCoder)); + } else { + stateInternals = SparkStateInternals.forKey(key); + } + + final byte[] byteValue = serializedValue.get(); + final WindowedValue windowedValue = CoderHelpers.fromByteArray(byteValue, this.wvCoder); + + final WindowedValue> keyedWindowedValue = + windowedValue.withValue(KV.of(key, windowedValue.getValue())); + + if (!wasSetupCalled) { + DoFnInvokers.tryInvokeSetupFor(this.doFn, this.options.get()); + this.wasSetupCalled = true; + } + + SparkInputDataProcessor, WindowedValue>> processor = + SparkInputDataProcessor.createUnbounded(); + + final StepContext context = + new StepContext() { + 
@Override + public StateInternals stateInternals() { + return stateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + }; + + DoFnRunner doFnRunner = + DoFnRunners.simpleRunner( + options.get(), + doFn, + CachedSideInputReader.of(new SparkSideInputReader(sideInputs)), + processor.getOutputManager(), + (TupleTag) mainOutputTag, + additionalOutputTags, + context, + inputCoder, + outputCoders, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping); + + final Coder windowCoder = + windowingStrategy.getWindowFn().windowCoder(); + + final StatefulDoFnRunner.CleanupTimer cleanUpTimer = + new StatefulDoFnRunner.TimeInternalsCleanupTimer<>(timerInternals, windowingStrategy); + + final StatefulDoFnRunner.StateCleaner stateCleaner = + new StatefulDoFnRunner.StateInternalsStateCleaner<>(doFn, stateInternals, windowCoder); + + doFnRunner = + DoFnRunners.defaultStatefulDoFnRunner( + doFn, inputCoder, doFnRunner, context, windowingStrategy, cleanUpTimer, stateCleaner); + + DoFnRunnerWithMetrics doFnRunnerWithMetrics = + new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum); + + SparkProcessContext ctx = + new SparkProcessContext<>( + stepName, doFn, doFnRunnerWithMetrics, key, timerInternals.getTimers().iterator()); + + final Iterator>> iterator = + Lists.newArrayList(keyedWindowedValue).iterator(); + + final Iterator, WindowedValue>> outputIterator = + processor.createOutputIterator((Iterator) iterator, ctx); + + state.update( + StateAndTimers.of( + stateInternals.getState(), + SparkTimerInternals.serializeTimers(timerInternals.getTimers(), timerDataCoder))); + + final List, WindowedValue>> resultList = + Lists.newArrayList(outputIterator); + + return (List, byte[]>>) + (List) + resultList.stream() + .map( + (Tuple2, WindowedValue> e) -> { + final TupleTag tupleTag = (TupleTag) e._1(); + final Coder outputCoder = + (Coder) outputCoders.get(tupleTag); + + @SuppressWarnings("nullness") + final 
WindowedValue.FullWindowedValueCoder outputWindowCoder = + WindowedValue.FullWindowedValueCoder.of(outputCoder, windowCoder); + + return Tuple2.apply( + tupleTag, + CoderHelpers.toByteArray((WindowedValue) e._2(), outputWindowCoder)); + }) + .collect(Collectors.toList()); + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java new file mode 100644 index 000000000000..23bcfcb129ce --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.translation.streaming; + +import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration; +import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectTimers; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator; +import org.apache.beam.runners.spark.stateful.StateAndTimers; +import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.SparkPCollectionView; +import org.apache.beam.runners.spark.translation.TransformEvaluator; +import org.apache.beam.runners.spark.translation.TranslationUtils; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; +import org.apache.beam.runners.spark.util.SideInputBroadcast; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.construction.ParDoTranslation; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import 
org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import scala.Option; +import scala.Tuple2; + +/** + * A specialized evaluator for ParDo operations in Spark Streaming context that is invoked when + * stateful streaming is detected in the DoFn. + * + *

This class is used by {@link StreamingTransformTranslator}'s ParDo evaluator to handle + * stateful streaming operations. When a DoFn contains stateful processing logic, the translation + * process routes the execution through this evaluator instead of the standard ParDo evaluator. + * + *

The evaluator manages state handling and ensures proper processing semantics for streaming + * stateful operations in the Spark runner context. + * + *

Important: This evaluator includes validation logic that rejects DoFn implementations + * containing {@code @Timer} annotations, as timer functionality is not currently supported in the + * Spark streaming context. + */ +public class StatefulStreamingParDoEvaluator + implements TransformEvaluator, OutputT>> { + + @Override + public void evaluate( + ParDo.MultiOutput, OutputT> transform, EvaluationContext context) { + final DoFn, OutputT> doFn = transform.getFn(); + final DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn); + + rejectTimers(doFn); + checkArgument( + !signature.processElement().isSplittable(), + "Splittable DoFn not yet supported in streaming mode: %s", + doFn); + checkState( + signature.onWindowExpiration() == null, "onWindowExpiration is not supported: %s", doFn); + + // options, PCollectionView, WindowingStrategy + final SerializablePipelineOptions options = context.getSerializableOptions(); + final SparkPCollectionView pviews = context.getPViews(); + final WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + + final KvCoder inputCoder = + (KvCoder) context.getInput(transform).getCoder(); + Map, Coder> outputCoders = context.getOutputCoders(); + JavaPairDStream, WindowedValue> all; + + final UnboundedDataset> unboundedDataset = + (UnboundedDataset>) context.borrowDataset(transform); + + final JavaDStream>> dStream = unboundedDataset.getDStream(); + + final DoFnSchemaInformation doFnSchemaInformation = + ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + + final Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + + final String stepName = context.getCurrentTransform().getFullName(); + + final WindowFn windowFn = windowingStrategy.getWindowFn(); + + final List sourceIds = unboundedDataset.getStreamSources(); + + // key, value coder + final Coder keyCoder = inputCoder.getKeyCoder(); + final Coder valueCoder = 
inputCoder.getValueCoder(); + + final WindowedValue.FullWindowedValueCoder wvCoder = + WindowedValue.FullWindowedValueCoder.of(valueCoder, windowFn.windowCoder()); + + final MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance(); + final Map, KV, SideInputBroadcast>> sideInputs = + TranslationUtils.getSideInputs( + transform.getSideInputs().values(), context.getSparkContext(), pviews); + + // Original code used multiple map operations (.map -> .mapToPair -> .mapToPair) + // which created intermediate RDDs for each transformation. + // Changed to use mapPartitionsToPair to: + // 1. Reduce the number of RDD creations by combining multiple operations + // 2. Process data in batches (partitions) rather than element by element + // 3. Improve performance by reducing serialization/deserialization overhead + // 4. Minimize the number of function objects created during execution + final JavaPairDStream< + /*Serialized KeyT*/ ByteArray, /*Serialized WindowedValue*/ byte[]> + serializedDStream = + dStream.mapPartitionsToPair( + (Iterator>> iter) -> + Iterators.transform( + iter, + (WindowedValue> windowedKV) -> { + final KeyT key = windowedKV.getValue().getKey(); + final WindowedValue windowedValue = + windowedKV.withValue(windowedKV.getValue().getValue()); + final ByteArray keyBytes = + new ByteArray(CoderHelpers.toByteArray(key, keyCoder)); + final byte[] valueBytes = + CoderHelpers.toByteArray(windowedValue, wvCoder); + return Tuple2.apply(keyBytes, valueBytes); + })); + + final Map watermarks = + GlobalWatermarkHolder.get(getBatchDuration(options)); + + @SuppressWarnings({"rawtypes", "unchecked"}) + final JavaMapWithStateDStream< + ByteArray, Option, State, List, byte[]>>> + processedPairDStream = + serializedDStream.mapWithState( + StateSpec.function( + new ParDoStateUpdateFn<>( + metricsAccum, + stepName, + doFn, + keyCoder, + (WindowedValue.FullWindowedValueCoder) wvCoder, + options, + transform.getMainOutputTag(), + 
transform.getAdditionalOutputTags().getAll(), + inputCoder, + outputCoders, + sideInputs, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping, + watermarks, + sourceIds))); + + all = + processedPairDStream.flatMapToPair( + (List, byte[]>> list) -> + Iterators.transform( + list.iterator(), + (Tuple2, byte[]> tuple) -> { + final Coder outputCoder = outputCoders.get(tuple._1()); + @SuppressWarnings("nullness") + final WindowedValue windowedValue = + CoderHelpers.fromByteArray( + tuple._2(), + WindowedValue.FullWindowedValueCoder.of( + outputCoder, windowFn.windowCoder())); + return Tuple2.apply(tuple._1(), windowedValue); + })); + + Map, PCollection> outputs = context.getOutputs(transform); + if (hasMultipleOutputs(outputs)) { + // Caching can cause Serialization, we need to code to bytes + // more details in https://issues.apache.org/jira/browse/BEAM-2669 + Map, Coder>> coderMap = + TranslationUtils.getTupleTagCoders(outputs); + all = + all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap)) + .cache() + .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap)); + + for (Map.Entry, PCollection> output : outputs.entrySet()) { + @SuppressWarnings({"unchecked", "rawtypes"}) + JavaPairDStream, WindowedValue> filtered = + all.filter(new TranslationUtils.TupleTagFilter(output.getKey())); + @SuppressWarnings("unchecked") + // Object is the best we can do since different outputs can have different tags + JavaDStream> values = + (JavaDStream>) + (JavaDStream) TranslationUtils.dStreamValues(filtered); + context.putDataset(output.getValue(), new UnboundedDataset<>(values, sourceIds)); + } + } else { + @SuppressWarnings("unchecked") + final JavaDStream> values = + (JavaDStream>) (JavaDStream) TranslationUtils.dStreamValues(all); + + context.putDataset( + Iterables.getOnlyElement(outputs.entrySet()).getValue(), + new UnboundedDataset<>(values, sourceIds)); + } + } + + @Override + public String toNativeString() { + return "mapPartitions(new ())"; + 
} + + private boolean hasMultipleOutputs(Map, PCollection> outputs) { + return outputs.size() > 1; + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 5be8e718dec6..539f8ff3efe6 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.spark.translation.streaming; -import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -65,6 +64,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -434,11 +434,27 @@ private static TransformEvaluator transform, final EvaluationContext context) { final DoFn doFn = transform.getFn(); + final DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn); checkArgument( - !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(), + !signature.processElement().isSplittable(), "Splittable DoFn not yet supported in streaming mode: %s", doFn); - rejectStateAndTimers(doFn); + checkState( + signature.onWindowExpiration() == null, + "onWindowExpiration is not supported: %s", + doFn); + + boolean 
stateful = + signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0; + + if (stateful) { + final StatefulStreamingParDoEvaluator delegate = + new StatefulStreamingParDoEvaluator<>(); + + delegate.evaluate((ParDo.MultiOutput) transform, context); + return; + } + final SerializablePipelineOptions options = context.getSerializableOptions(); final SparkPCollectionView pviews = context.getPViews(); final WindowingStrategy windowingStrategy = diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java index ed7bc078564e..fd299924af91 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java @@ -18,12 +18,6 @@ package org.apache.beam.runners.spark.translation; import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import java.util.Arrays; import java.util.Iterator; @@ -45,9 +39,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes; -import org.apache.spark.Partitioner; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Assert; @@ -121,54 +112,6 @@ public void testGbkIteratorValuesCannotBeReiterated() throws Coder.NonDeterminis } } - @Test - @SuppressWarnings({"rawtypes", "unchecked"}) - public void 
testGroupByKeyInGlobalWindowWithPartitioner() { - // mocking - Partitioner mockPartitioner = mock(Partitioner.class); - JavaRDD mockRdd = mock(JavaRDD.class); - Coder mockKeyCoder = mock(Coder.class); - Coder mockValueCoder = mock(Coder.class); - JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class); - JavaPairRDD mockGrouped = mock(JavaPairRDD.class); - - when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues); - when(mockRawKeyValues.groupByKey(any(Partitioner.class))) - .thenAnswer( - invocation -> { - Partitioner partitioner = invocation.getArgument(0); - assertEquals(partitioner, mockPartitioner); - return mockGrouped; - }); - when(mockGrouped.map(any())).thenReturn(mock(JavaRDD.class)); - - GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow( - mockRdd, mockKeyCoder, mockValueCoder, mockPartitioner); - - verify(mockRawKeyValues, never()).groupByKey(); - verify(mockRawKeyValues, times(1)).groupByKey(any(Partitioner.class)); - } - - @Test - @SuppressWarnings({"rawtypes", "unchecked"}) - public void testGroupByKeyInGlobalWindowWithoutPartitioner() { - // mocking - JavaRDD mockRdd = mock(JavaRDD.class); - Coder mockKeyCoder = mock(Coder.class); - Coder mockValueCoder = mock(Coder.class); - JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class); - JavaPairRDD mockGrouped = mock(JavaPairRDD.class); - - when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues); - when(mockRawKeyValues.groupByKey()).thenReturn(mockGrouped); - - GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow( - mockRdd, mockKeyCoder, mockValueCoder, null); - - verify(mockRawKeyValues, times(1)).groupByKey(); - verify(mockRawKeyValues, never()).groupByKey(any(Partitioner.class)); - } - private GroupByKeyIterator createGbkIterator() throws Coder.NonDeterministicException { return createGbkIterator( diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java index a3d7724e4363..243f3a3e533f 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java @@ -527,7 +527,7 @@ public void process(ProcessContext context) { } } - private static PipelineOptions streamingOptions() { + static PipelineOptions streamingOptions() { PipelineOptions options = TestPipeline.testingPipelineOptions(); options.as(TestSparkPipelineOptions.class).setStreaming(true); return options; diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java new file mode 100644 index 000000000000..e1f000d16675 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.translation.streaming; + +import static org.apache.beam.runners.spark.translation.streaming.CreateStreamTest.streamingOptions; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.StreamingTest; +import org.apache.beam.runners.spark.io.CreateStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +@SuppressWarnings({"unchecked", "unused"}) +public class StatefulStreamingParDoEvaluatorTest { + + @Rule public final transient TestPipeline p = TestPipeline.fromOptions(streamingOptions()); + + private PTransform>> createStreamingSource( + Pipeline pipeline) { + Instant instant = new Instant(0); + 
final KvCoder coder = KvCoder.of(VarIntCoder.of(), VarIntCoder.of()); + final Duration batchDuration = batchDuration(pipeline); + return CreateStream.of(coder, batchDuration) + .emptyBatch() + .advanceWatermarkForNextBatch(instant) + .nextBatch( + TimestampedValue.of(KV.of(1, 1), instant), + TimestampedValue.of(KV.of(1, 2), instant), + TimestampedValue.of(KV.of(1, 3), instant)) + .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L))) + .nextBatch( + TimestampedValue.of(KV.of(2, 4), instant.plus(Duration.standardSeconds(1L))), + TimestampedValue.of(KV.of(2, 5), instant.plus(Duration.standardSeconds(1L))), + TimestampedValue.of(KV.of(2, 6), instant.plus(Duration.standardSeconds(1L)))) + .advanceNextBatchWatermarkToInfinity(); + } + + private PTransform>> createStreamingSource( + Pipeline pipeline, int iterCount) { + Instant instant = new Instant(0); + final KvCoder coder = KvCoder.of(VarIntCoder.of(), VarIntCoder.of()); + final Duration batchDuration = batchDuration(pipeline); + + CreateStream> createStream = + CreateStream.of(coder, batchDuration).emptyBatch().advanceWatermarkForNextBatch(instant); + + int value = 1; + for (int i = 0; i < iterCount; i++) { + createStream = + createStream.nextBatch( + TimestampedValue.of(KV.of(1, value++), instant), + TimestampedValue.of(KV.of(1, value++), instant), + TimestampedValue.of(KV.of(1, value++), instant)); + + instant = instant.plus(Duration.standardSeconds(1L)); + createStream = createStream.advanceWatermarkForNextBatch(instant); + + createStream = + createStream.nextBatch( + TimestampedValue.of(KV.of(2, value++), instant), + TimestampedValue.of(KV.of(2, value++), instant), + TimestampedValue.of(KV.of(2, value++), instant)); + + instant = instant.plus(Duration.standardSeconds(1L)); + createStream = createStream.advanceWatermarkForNextBatch(instant); + } + + return createStream.advanceNextBatchWatermarkToInfinity(); + } + + private static class StatefulWithTimerDoFn extends DoFn { + 
@StateId("some-state") + private final StateSpec> someStringStateSpec = + StateSpecs.value(StringUtf8Coder.of()); + + @TimerId("some-timer") + private final TimerSpec someTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + @ProcessElement + public void process( + @Element InputT element, @StateId("some-state") ValueState someStringStage) { + // ignore... + } + + @OnTimer("some-timer") + public void onTimer() { + // ignore... + } + } + + private static class StatefulDoFn extends DoFn, KV> { + + @StateId("test-state") + private final StateSpec> testState = StateSpecs.value(); + + @ProcessElement + public void process( + @Element KV element, + @StateId("test-state") ValueState testState, + OutputReceiver> output) { + final Integer value = element.getValue(); + final Integer currentState = firstNonNull(testState.read(), 0); + final Integer newState = currentState + value; + testState.write(newState); + + final KV result = KV.of(element.getKey(), newState); + output.output(result); + } + } + + @Category(StreamingTest.class) + @Test + public void shouldRejectTimer() { + p.apply(createStreamingSource(p)).apply(ParDo.of(new StatefulWithTimerDoFn<>())); + + final UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, p::run); + + assertEquals( + "Found TimerId annotations on " + + StatefulWithTimerDoFn.class.getName() + + ", but DoFn cannot yet be used with timers in the SparkRunner.", + exception.getMessage()); + } + + @Category(StreamingTest.class) + @Test + public void shouldProcessGlobalWidowStatefulParDo() { + final PCollection> result = + p.apply(createStreamingSource(p)).apply(ParDo.of(new StatefulDoFn())); + + PAssert.that(result) + .containsInAnyOrder( + // key 1 + KV.of(1, 1), // 1 + KV.of(1, 3), // 1 + 2 + KV.of(1, 6), // 3 + 3 + // key 2 + KV.of(2, 4), // 4 + KV.of(2, 9), // 4 + 5 + KV.of(2, 15)); // 9 + 6 + + p.run().waitUntilFinish(); + } + + @Category(StreamingTest.class) + @Test + public void 
shouldProcessWindowedStatefulParDo() { + final PCollection> result = + p.apply(createStreamingSource(p, 2)) + .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1L)))) + .apply(ParDo.of(new StatefulDoFn())); + + PAssert.that(result) + .containsInAnyOrder( + // Windowed Key 1 + KV.of(1, 1), // 1 + KV.of(1, 3), // 1 + 2 + KV.of(1, 6), // 3 + 3 + + // Windowed Key 2 + KV.of(2, 4), // 4 + KV.of(2, 9), // 4 + 5 + KV.of(2, 15), // 9 + 6 + + // Windowed Key 1 + KV.of(1, 7), // 7 + KV.of(1, 15), // 7 + 8 + KV.of(1, 24), // 15 + 9 + + // Windowed Key 2 + KV.of(2, 10), // 10 + KV.of(2, 21), // 10 + 11 + KV.of(2, 33) // 21 + 12 + ); + + p.run().waitUntilFinish(); + } + + private Duration batchDuration(Pipeline pipeline) { + return Duration.millis( + pipeline.getOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis()); + } +} diff --git a/scripts/ci/pr-bot/shared/checks.ts b/scripts/ci/pr-bot/shared/checks.ts index f27830a1dc29..187ff5771f96 100644 --- a/scripts/ci/pr-bot/shared/checks.ts +++ b/scripts/ci/pr-bot/shared/checks.ts @@ -40,7 +40,8 @@ export async function getChecksStatus( } if ( mostRecentChecks[i].conclusion != "success" && - mostRecentChecks[i].conclusion != "skipped" + mostRecentChecks[i].conclusion != "skipped" && + mostRecentChecks[i].conclusion != "neutral" ) { checkStatus.succeeded = false; } diff --git a/sdks/go.mod b/sdks/go.mod index dc380debf3ed..27ec2236c9ca 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -23,18 +23,18 @@ module github.com/apache/beam/sdks/v2 go 1.21.0 require ( - cloud.google.com/go/bigquery v1.64.0 + cloud.google.com/go/bigquery v1.65.0 cloud.google.com/go/bigtable v1.33.0 cloud.google.com/go/datastore v1.20.0 - cloud.google.com/go/profiler v0.4.1 + cloud.google.com/go/profiler v0.4.2 cloud.google.com/go/pubsub v1.45.3 cloud.google.com/go/spanner v1.73.0 - cloud.google.com/go/storage v1.47.0 - github.com/aws/aws-sdk-go-v2 v1.32.6 - github.com/aws/aws-sdk-go-v2/config v1.28.4 - github.com/aws/aws-sdk-go-v2/credentials 
v1.17.47 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.38 - github.com/aws/aws-sdk-go-v2/service/s3 v1.67.0 + cloud.google.com/go/storage v1.48.0 + github.com/aws/aws-sdk-go-v2 v1.32.7 + github.com/aws/aws-sdk-go-v2/config v1.28.7 + github.com/aws/aws-sdk-go-v2/credentials v1.17.48 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 + github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 github.com/aws/smithy-go v1.22.1 github.com/docker/go-connections v0.5.0 github.com/dustin/go-humanize v1.0.1 @@ -44,7 +44,7 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20221110173912-32fb85c5aed6 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 - github.com/nats-io/nats-server/v2 v2.10.22 + github.com/nats-io/nats-server/v2 v2.10.23 github.com/nats-io/nats.go v1.37.0 github.com/proullon/ramsql v0.1.4 github.com/spf13/cobra v1.8.1 @@ -53,15 +53,15 @@ require ( github.com/xitongsys/parquet-go v1.6.2 github.com/xitongsys/parquet-go-source v0.0.0-20220315005136-aec0fe3e777c go.mongodb.org/mongo-driver v1.17.1 - golang.org/x/net v0.31.0 + golang.org/x/net v0.33.0 golang.org/x/oauth2 v0.24.0 - golang.org/x/sync v0.9.0 - golang.org/x/sys v0.27.0 - golang.org/x/text v0.20.0 - google.golang.org/api v0.210.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 + golang.org/x/text v0.21.0 + google.golang.org/api v0.214.0 google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 - google.golang.org/grpc v1.67.1 - google.golang.org/protobuf v1.35.2 + google.golang.org/grpc v1.67.2 + google.golang.org/protobuf v1.36.0 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -75,7 +75,7 @@ require ( require ( cel.dev/expr v0.16.1 // indirect - cloud.google.com/go/auth v0.11.0 // indirect + cloud.google.com/go/auth v0.13.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/monitoring v1.21.2 // indirect dario.cat/mergo v1.0.0 // indirect @@ -98,7 +98,7 @@ require ( github.com/moby/sys/user v0.1.0 // indirect 
github.com/moby/sys/userns v0.1.0 // indirect github.com/nats-io/jwt/v2 v2.5.8 // indirect - github.com/nats-io/nkeys v0.4.7 // indirect + github.com/nats-io/nkeys v0.4.8 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect @@ -123,7 +123,7 @@ require ( require ( cloud.google.com/go v0.116.0 // indirect - cloud.google.com/go/compute/metadata v0.5.2 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/iam v1.2.2 // indirect cloud.google.com/go/longrunning v0.6.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect @@ -131,25 +131,25 @@ require ( github.com/apache/arrow/go/arrow v0.0.0-20200730104253-651201b0f516 // indirect github.com/apache/thrift v0.17.0 // indirect github.com/aws/aws-sdk-go v1.34.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.26 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect - 
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.3 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/docker/docker v27.3.1+incompatible // but required to resolve issue docker has with go1.20 + github.com/docker/docker v27.4.1+incompatible // but required to resolve issue docker has with go1.20 github.com/docker/go-units v0.5.0 // indirect github.com/envoyproxy/go-control-plane v0.13.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect @@ -159,7 +159,7 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v23.5.26+incompatible // indirect - github.com/google/pprof v0.0.0-20240528025155-186aa0362fba // indirect + github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect @@ -190,10 +190,10 @@ require ( github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect 
github.com/zeebo/xxh3 v1.0.2 // indirect go.opencensus.io v0.24.0 // indirect - golang.org/x/crypto v0.29.0 // indirect + golang.org/x/crypto v0.31.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20241113202542-65e8d215514f // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect ) diff --git a/sdks/go.sum b/sdks/go.sum index eae2e9ad22b6..907756d3e05e 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -101,8 +101,8 @@ cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.11.0 h1:Ic5SZz2lsvbYcWT5dfjNWgw6tTlGi2Wc8hyQSC9BstA= -cloud.google.com/go/auth v0.11.0/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= +cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs= +cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q= cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU= cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= @@ -133,8 +133,8 @@ cloud.google.com/go/bigquery v1.47.0/go.mod h1:sA9XOgy0A8vQK9+MWhEQTY6Tix87M/Zur cloud.google.com/go/bigquery v1.48.0/go.mod 
h1:QAwSz+ipNgfL5jxiaK7weyOhzdoAy1zFm0Nf1fysJac= cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9yBh7Oy7/4Q= cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU= -cloud.google.com/go/bigquery v1.64.0 h1:vSSZisNyhr2ioJE1OuYBQrnrpB7pIhRQm4jfjc7E/js= -cloud.google.com/go/bigquery v1.64.0/go.mod h1:gy8Ooz6HF7QmA+TRtX8tZmXBKH5mCFBwUApGAb3zI7Y= +cloud.google.com/go/bigquery v1.65.0 h1:ZZ1EOJMHTYf6R9lhxIXZJic1qBD4/x9loBIS+82moUs= +cloud.google.com/go/bigquery v1.65.0/go.mod h1:9WXejQ9s5YkTW4ryDYzKXBooL78u5+akWGXgJqQkY6A= cloud.google.com/go/bigtable v1.33.0 h1:2BDaWLRAwXO14DJL/u8crbV2oUbMZkIa2eGq8Yao1bk= cloud.google.com/go/bigtable v1.33.0/go.mod h1:HtpnH4g25VT1pejHRtInlFPnN5sjTxbQlsYBjh9t5l0= cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY= @@ -188,8 +188,8 @@ cloud.google.com/go/compute/metadata v0.1.0/go.mod h1:Z1VN+bulIf6bt4P/C37K4DyZYZ cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.2.1/go.mod h1:jgHgmJd2RKBGzXqF5LR2EZMGxBkeanZ9wwa75XHJgOM= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= -cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= -cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/contactcenterinsights v1.3.0/go.mod h1:Eu2oemoePuEFc/xKFPjbTuPSj0fYJcPls9TFlPNnHHY= cloud.google.com/go/contactcenterinsights v1.4.0/go.mod h1:L2YzkGbPsv+vMQMCADxJoT9YiTTnSEd6fEvCeHTYVck= cloud.google.com/go/contactcenterinsights v1.6.0/go.mod h1:IIDlT6CLcDoyv79kDv8iWxMSTZhLxSCofVV5W6YFM/w= @@ -441,8 +441,8 @@ 
cloud.google.com/go/privatecatalog v0.5.0/go.mod h1:XgosMUvvPyxDjAVNDYxJ7wBW8//h cloud.google.com/go/privatecatalog v0.6.0/go.mod h1:i/fbkZR0hLN29eEWiiwue8Pb+GforiEIBnV9yrRUOKI= cloud.google.com/go/privatecatalog v0.7.0/go.mod h1:2s5ssIFO69F5csTXcwBP7NPFTZvps26xGzvQ2PQaBYg= cloud.google.com/go/privatecatalog v0.8.0/go.mod h1:nQ6pfaegeDAq/Q5lrfCQzQLhubPiZhSaNhIgfJlnIXs= -cloud.google.com/go/profiler v0.4.1 h1:Q7+lOvikTGMJ/IAWocpYYGit4SIIoILmVZfEEWTORSY= -cloud.google.com/go/profiler v0.4.1/go.mod h1:LBrtEX6nbvhv1w/e5CPZmX9ajGG9BGLtGbv56Tg4SHs= +cloud.google.com/go/profiler v0.4.2 h1:KojCmZ+bEPIQrd7bo2UFvZ2xUPLHl55KzHl7iaR4V2I= +cloud.google.com/go/profiler v0.4.2/go.mod h1:7GcWzs9deJHHdJ5J9V1DzKQ9JoIoTGhezwlLbwkOoCs= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= @@ -561,8 +561,8 @@ cloud.google.com/go/storage v1.23.0/go.mod h1:vOEEDNFnciUMhBeT6hsJIn3ieU5cFRmzeL cloud.google.com/go/storage v1.27.0/go.mod h1:x9DOL8TK/ygDUMieqwfhdpQryTeEkhGKMi80i/iqR2s= cloud.google.com/go/storage v1.28.1/go.mod h1:Qnisd4CqDdo6BGs2AD5LLnEsmSQ80wQ5ogcBBKhU86Y= cloud.google.com/go/storage v1.29.0/go.mod h1:4puEjyTKnku6gfKoTfNOU/W+a9JyuVNxjpS5GBrB8h4= -cloud.google.com/go/storage v1.47.0 h1:ajqgt30fnOMmLfWfu1PWcb+V9Dxz6n+9WKjdNg5R4HM= -cloud.google.com/go/storage v1.47.0/go.mod h1:Ks0vP374w0PW6jOUameJbapbQKXqkjGd/OJRp2fb9IQ= +cloud.google.com/go/storage v1.48.0 h1:FhBDHACbVtdPx7S/AbcKujPWiHvfO6F8OXGgCEbB2+o= +cloud.google.com/go/storage v1.48.0/go.mod h1:aFoDYNMAjv67lp+xcuZqjUKv/ctmplzQ3wJgodA7b+M= cloud.google.com/go/storagetransfer v1.5.0/go.mod h1:dxNzUopWy7RQevYFHewchb29POFv3/AaBgnhqzqiK0w= cloud.google.com/go/storagetransfer v1.6.0/go.mod h1:y77xm4CQV/ZhFZH75PLEXY0ROiS7Gh6pSKrM8dJyg6I= cloud.google.com/go/storagetransfer v1.7.0/go.mod 
h1:8Giuj1QNb1kfLAiWM1bN6dHzfdlDAVC9rv9abHot2W4= @@ -689,53 +689,53 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= -github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= -github.com/aws/aws-sdk-go-v2/config v1.28.4 h1:qgD0MKmkIzZR2DrAjWJcI9UkndjR+8f6sjUQvXh0mb0= -github.com/aws/aws-sdk-go-v2/config v1.28.4/go.mod h1:LgnWnNzHZw4MLplSyEGia0WgJ/kCGD86zGCjvNpehJs= +github.com/aws/aws-sdk-go-v2/config v1.28.7 h1:GduUnoTXlhkgnxTD93g1nv4tVPILbdNQOzav+Wpg7AE= +github.com/aws/aws-sdk-go-v2/config v1.28.7/go.mod h1:vZGX6GVkIE8uECSUHB6MWAUsd4ZcG2Yq/dMa4refR3M= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= -github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= -github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod 
h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= +github.com/aws/aws-sdk-go-v2/credentials v1.17.48 h1:IYdLD1qTJ0zanRavulofmqut4afs45mOWEI+MzZtTfQ= +github.com/aws/aws-sdk-go-v2/credentials v1.17.48/go.mod h1:tOscxHN3CGmuX9idQ3+qbkzrjVIx32lqDSU1/0d/qXs= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.38 h1:xN0PViSptTHJ7QIKyWeWntuTCZoejutTPfhsZIoMDy0= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.38/go.mod h1:orUzUoWBICDyc+hz49KpySb3sa2Tw3c0IaFqrH4c4dg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 h1:iLdpkYZ4cXIQMO7ud+cqMWR1xK5ESbt1rvN77tRi1BY= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43/go.mod h1:OgbsKPAswXDd5kxnR4vZov69p3oYjbvUyIRBAAV0y9o= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod 
h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 h1:zXFLuEuMMUOvEARXFUVJdfqZ4bvvSgdGRq/ATcrQxzM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26/go.mod h1:3o2Wpy0bogG1kyOPrgkXA8pgIfEEv0+m19O9D5+W8y8= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23 h1:1SZBDiRzzs3sNhOMVApyWPduWYGAX0imGy06XiBnCAM= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23/go.mod h1:i9TkxgbZmHVh2S0La6CAXtnyFhlCX/pJ0JsOvBAS6Mk= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.26 h1:GeNJsIFHB+WW5ap2Tec4K6dzcVTsRbsT1Lra46Hv9ME= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.26/go.mod h1:zfgMpwHDXX2WGoG84xG2H+ZlPTkJUU4YUvx2svLQYWo= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4 h1:aaPpoG15S2qHkWm4KlEyF01zovK1nW4BBbyXuHNSE90= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4/go.mod h1:eD9gS2EARTKgGr/W5xwgY/ik9z/zqpW+m/xOQbVxrMk= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.7 h1:tB4tNw83KcajNAzaIMhkhVI2Nt8fAZd5A5ro113FEMY= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.7/go.mod h1:lvpyBGkZ3tZ9iSsUIcC2EWp+0ywa7aK3BLT+FwZi+mQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url 
v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7 h1:8eUsivBQzZHqe/3FE+cqwfH+0p5Jo8PFM/QYQSmeZ+M= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7/go.mod h1:kLPQvGUmxn/fqiCrDeohwG33bq2pQpGeY62yRO6Nrh0= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4 h1:E5ZAVOmI2apR8ADb72Q63KqwwwdW1XcMeXIlrZ1Psjg= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4/go.mod h1:wezzqVUOVVdk+2Z/JzQT4NxAU0NbhRe5W8pIE72jsWI= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.7 h1:Hi0KGbrnr57bEHWM0bJ1QcBzxLrL/k2DHvGYhb8+W1w= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.7/go.mod h1:wKNgWgExdjjrm4qvfbTorkvocEstaoDl4WCvGfeCy9c= github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI= -github.com/aws/aws-sdk-go-v2/service/s3 v1.67.0 h1:SwaJ0w0MOp0pBTIKTamLVeTKD+iOWyNJRdJ2KCQRg6Q= -github.com/aws/aws-sdk-go-v2/service/s3 v1.67.0/go.mod h1:TMhLIyRIyoGVlaEMAt+ITMbwskSTpcGsCPDq91/ihY0= +github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 h1:SAfh4pNx5LuTafKKWR02Y+hL3A+3TX8cTKG1OIAJaBk= +github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0/go.mod h1:r+xl5yzMk9083rMR+sJ5TYj9Tihvf/l1oxzZXDgGj2Q= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod 
h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 h1:CvuUmnXI7ebaUAhbJcDy9YQx8wHR69eZ9I7q5hszt/g= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.8/go.mod h1:XDeGv1opzwm8ubxddF0cgqkZWsyOtw4lr6dxwmb6YQg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7 h1:F2rBfNAL5UyswqoeWv9zs74N/NanhK16ydHW1pahX6E= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7/go.mod h1:JfyQ0g2JG8+Krq0EuZNnRwX0mU0HrwY/tG6JNfcqh4k= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.3 h1:Xgv/hyNgvLda/M9l9qxXc4UFSgppnRczLxlMs5Ae/QY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.3/go.mod h1:5Gn+d+VaaRgsjewpMvGazt0WfcFO+Md4wLOuBfGR9Bc= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -787,8 +787,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v27.3.1+incompatible h1:KttF0XoteNTicmUtBO0L2tP+J7FGRFTjaEF4k6WdhfI= -github.com/docker/docker v27.3.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v27.4.1+incompatible h1:ZJvcY7gfwHn1JF48PfbyXg7Jyt9ZCWDW+GGXOIxEwp4= +github.com/docker/docker v27.4.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= 
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -948,8 +948,8 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20240528025155-186aa0362fba h1:ql1qNgCyOB7iAEk8JTNM+zJrgIbnyCKX/wdlyPufP5g= -github.com/google/pprof v0.0.0-20240528025155-186aa0362fba/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= @@ -1085,12 +1085,12 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/nats-io/jwt/v2 v2.5.8 h1:uvdSzwWiEGWGXf+0Q+70qv6AQdvcvxrv9hPM0RiPamE= github.com/nats-io/jwt/v2 v2.5.8/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.10.22 h1:Yt63BGu2c3DdMoBZNcR6pjGQwk/asrKU7VX846ibxDA= -github.com/nats-io/nats-server/v2 v2.10.22/go.mod h1:X/m1ye9NYansUXYFrbcDwUi/blHkrgHh2rgCJaakonk= 
+github.com/nats-io/nats-server/v2 v2.10.23 h1:jvfb9cEi5h8UG6HkZgJGdn9f1UPaX3Dohk0PohEekJI= +github.com/nats-io/nats-server/v2 v2.10.23/go.mod h1:hMFnpDT2XUXsvHglABlFl/uroQCCOcW6X/0esW6GpBk= github.com/nats-io/nats.go v1.37.0 h1:07rauXbVnnJvv1gfIyghFEo6lUcYRY0WXc3x7x0vUxE= github.com/nats-io/nats.go v1.37.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= -github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= -github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= +github.com/nats-io/nkeys v0.4.8 h1:+wee30071y3vCZAYRsnrmIPaOe47A/SkK/UBDPdIV70= +github.com/nats-io/nkeys v0.4.8/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/ncw/swift v1.0.52/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= @@ -1266,8 +1266,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= -golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod 
h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1388,8 +1388,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= -golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1437,8 +1437,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys 
v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1526,8 +1526,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= @@ -1536,8 +1536,8 @@ golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU= -golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ 
-1554,8 +1554,8 @@ golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1707,8 +1707,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.210.0 h1:HMNffZ57OoZCRYSbdWVRoqOa8V8NIHLL0CzdBPLztWk= -google.golang.org/api v0.210.0/go.mod h1:B9XDZGnx2NtyjzVkOVTGrFSAVZgPcbedzKg/gTLwqBs= +google.golang.org/api v0.214.0 h1:h2Gkq07OYi6kusGOaT/9rnNljuXmqPnaig7WGPmKbwA= +google.golang.org/api v0.214.0/go.mod h1:bYPpLG8AyeMWwDU6NXoB00xC0DFkikVvd5MfwoxjLqE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1850,10 +1850,10 @@ google.golang.org/genproto 
v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 h1:ToEetK57OidYuqD4Q5w+vfEnPvPpuTwedCNVohYJfNk= google.golang.org/genproto v0.0.0-20241118233622-e639e219e697/go.mod h1:JJrvXBWRZaFMxBufik1a4RpFw4HhgVtBBWQeQgUj2cc= -google.golang.org/genproto/googleapis/api v0.0.0-20241113202542-65e8d215514f h1:M65LEviCfuZTfrfzwwEoxVtgvfkFkBUbFnRbxCXuXhU= -google.golang.org/genproto/googleapis/api v0.0.0-20241113202542-65e8d215514f/go.mod h1:Yo94eF2nj7igQt+TiJ49KxjIH8ndLYPZMIRSiRcEbg0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697 h1:LWZqQOEjDyONlF1H6afSWpAL/znlREo2tHfLoe+8LMA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= +google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 h1:pgr/4QbFyktUv9CtQ/Fq4gzEE6/Xs7iCXbktaGzLHbQ= +google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697/go.mod h1:+D9ySVjN8nY8YCVjc5O7PZDIdZporIDY3KaGfJunh88= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1895,8 +1895,8 @@ google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5v google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g= google.golang.org/grpc 
v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= -google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= +google.golang.org/grpc v1.67.2 h1:Lq11HW1nr5m4OYV+ZVy2BjOK78/zqnTx24vyDBP1JcQ= +google.golang.org/grpc v1.67.2/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a h1:UIpYSuWdWHSzjwcAFRLjKcPXFZVVLXGEM23W+NWqipw= google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a/go.mod h1:9i1T9n4ZinTUZGgzENMi8MDDgbGC5mqTS75JAv6xN3A= @@ -1917,8 +1917,8 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= -google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.0 h1:mjIs9gYtt56AzC4ZaffQuh88TZurBGhIJMBZGSxNerQ= +google.golang.org/protobuf v1.36.0/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/sdks/go/pkg/beam/core/core.go b/sdks/go/pkg/beam/core/core.go index 1b478f483077..a183ddf384ed 100644 --- a/sdks/go/pkg/beam/core/core.go +++ b/sdks/go/pkg/beam/core/core.go @@ -27,7 
+27,7 @@ const ( // SdkName is the human readable name of the SDK for UserAgents. SdkName = "Apache Beam SDK for Go" // SdkVersion is the current version of the SDK. - SdkVersion = "2.62.0.dev" + SdkVersion = "2.63.0.dev" // DefaultDockerImage represents the associated image for this release. DefaultDockerImage = "apache/beam_go_sdk:" + SdkVersion diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go index 92c4d0a8f8cd..9f6f8a986a3f 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go @@ -261,22 +261,6 @@ func TestElementChan(t *testing.T) { return elms }, wantSum: 6, wantCount: 3, - }, { - name: "FillBufferThenAbortThenRead", - sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { - for i := 0; i < bufElements+2; i++ { - client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) - } - elms := openChan(ctx, t, c, timerID) - c.removeInstruction(instID) - - // These will be ignored - client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) - client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(2, false)}}) - client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) - return elms - }, - wantSum: bufElements, wantCount: bufElements, }, { name: "DataThenReaderThenLast", sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { @@ -389,18 +373,6 @@ func TestElementChan(t *testing.T) { return elms }, wantSum: 0, wantCount: 0, - }, { - name: "SomeTimersAndADataThenReaderThenCleanup", - sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { - client.Send(&fnpb.Elements{ - Timers: []*fnpb.Elements_Timers{timerElm(1, false), timerElm(2, true)}, - Data: []*fnpb.Elements_Data{dataElm(3, true)}, - }) - 
elms := openChan(ctx, t, c, timerID) - c.removeInstruction(instID) - return elms - }, - wantSum: 6, wantCount: 3, }, } for _, test := range tests { diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index 7b8689f95112..380b6e2f31d1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -17,13 +17,17 @@ package engine import ( "bytes" + "cmp" "fmt" "log/slog" + "slices" + "sort" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "google.golang.org/protobuf/encoding/protowire" ) // StateData is a "union" between Bag state and MultiMap state to increase common code. @@ -42,6 +46,10 @@ type TimerKey struct { type TentativeData struct { Raw map[string][][]byte + // stateTypeLen is a map from LinkID to valueLen function for parsing data. + // Only used by OrderedListState, since Prism must manipulate these datavalues, + // which isn't expected, or a requirement of other state values. + stateTypeLen map[LinkID]func([]byte) int // state is a map from transformID + UserStateID, to window, to userKey, to datavalues. state map[LinkID]map[typex.Window]map[string]StateData // timers is a map from the Timer transform+family to the encoded timer. @@ -220,3 +228,92 @@ func (d *TentativeData) ClearMultimapKeysState(stateID LinkID, wKey, uKey []byte kmap[string(uKey)] = StateData{} slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey)) } + +// AppendOrderedListState appends the incoming timestamped data to the existing tentative data bundle. +// Assumes the data is TimestampedValue encoded, which has a BigEndian int64 suffixed to the data. 
+// This means we may always use the last 8 bytes to determine the value sorting. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey []byte, data []byte) { + kmap := d.appendState(stateID, wKey) + typeLen := d.stateTypeLen[stateID] + var datums [][]byte + + // We need to parse out all values individually for later sorting. + // + // OrderedListState is encoded as KVs with varint encoded millis followed by the value. + // This is not the standard TimestampValueCoder encoding, which + // uses a big-endian long as a suffix to the value. This is important since + // values may be concatenated, and we'll need to split them out out. + // + // The TentativeData.stateTypeLen is populated with a function to extract + // the length of a the next value, so we can skip through elements individually. + for i := 0; i < len(data); { + // Get the length of the VarInt for the timestamp. + _, tn := protowire.ConsumeVarint(data[i:]) + + // Get the length of the encoded value. + vn := typeLen(data[i+tn:]) + prev := i + i += tn + vn + datums = append(datums, data[prev:i]) + } + + s := StateData{Bag: append(kmap[string(uKey)].Bag, datums...)} + sort.SliceStable(s.Bag, func(i, j int) bool { + vi := s.Bag[i] + vj := s.Bag[j] + return compareTimestampSuffixes(vi, vj) + }) + kmap[string(uKey)] = s + slog.Debug("State() OrderedList.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", s)) +} + +func compareTimestampSuffixes(vi, vj []byte) bool { + ims, _ := protowire.ConsumeVarint(vi) + jms, _ := protowire.ConsumeVarint(vj) + return (int64(ims)) < (int64(jms)) +} + +// GetOrderedListState available state from the tentative bundle data. +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. 
+func (d *TentativeData) GetOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) [][]byte { + winMap := d.state[stateID] + w := d.toWindow(wKey) + data := winMap[w][string(uKey)] + + lo, hi := findRange(data.Bag, start, end) + slog.Debug("State() OrderedList.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), slog.Group("outrange", slog.Int("lo", lo), slog.Int("hi", hi)), slog.Any("Data", data.Bag[lo:hi])) + return data.Bag[lo:hi] +} + +func cmpSuffix(vs [][]byte, target int64) func(i int) int { + return func(i int) int { + v := vs[i] + ims, _ := protowire.ConsumeVarint(v) + tvsbi := cmp.Compare(target, int64(ims)) + slog.Debug("cmpSuffix", "target", target, "bi", ims, "tvsbi", tvsbi) + return tvsbi + } +} + +func findRange(bag [][]byte, start, end int64) (int, int) { + lo, _ := sort.Find(len(bag), cmpSuffix(bag, start)) + hi, _ := sort.Find(len(bag), cmpSuffix(bag, end)) + return lo, hi +} + +func (d *TentativeData) ClearOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) { + winMap := d.state[stateID] + w := d.toWindow(wKey) + kMap := winMap[w] + data := kMap[string(uKey)] + + lo, hi := findRange(data.Bag, start, end) + slog.Debug("State() OrderedList.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), "lo", lo, "hi", hi, slog.Any("PreClearData", data.Bag)) + + cleared := slices.Delete(data.Bag, lo, hi) + // Zero the current entry to clear. + // Delete makes it difficult to delete the persisted stage state for the key. 
+ kMap[string(uKey)] = StateData{Bag: cleared} +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go new file mode 100644 index 000000000000..1d0497104182 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package engine + +import ( + "bytes" + "encoding/binary" + "math" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestCompareTimestampSuffixes(t *testing.T) { + t.Run("simple", func(t *testing.T) { + loI := int64(math.MinInt64) + hiI := int64(math.MaxInt64) + + loB := binary.BigEndian.AppendUint64(nil, uint64(loI)) + hiB := binary.BigEndian.AppendUint64(nil, uint64(hiI)) + + if compareTimestampSuffixes(loB, hiB) != (loI < hiI) { + t.Errorf("lo vs Hi%v < %v: bytes %v vs %v, %v %v", loI, hiI, loB, hiB, loI < hiI, compareTimestampSuffixes(loB, hiB)) + } + }) +} + +func TestOrderedListState(t *testing.T) { + time1 := protowire.AppendVarint(nil, 11) + time2 := protowire.AppendVarint(nil, 22) + time3 := protowire.AppendVarint(nil, 33) + time4 := protowire.AppendVarint(nil, 44) + time5 := protowire.AppendVarint(nil, 55) + + wKey := []byte{} // global window. + uKey := []byte("\u0007userkey") + linkID := LinkID{ + Transform: "dofn", + Local: "localStateName", + } + cc := func(a []byte, b ...byte) []byte { + return bytes.Join([][]byte{a, b}, []byte{}) + } + + t.Run("bool", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(_ []byte) int { + return 1 + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, 1), + cc(time2, 0), + cc(time3, 1), + cc(time4, 0), + cc(time5, 1), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList booleans \n%v", d) + } + + d.ClearOrderedListState(linkID, wKey, uKey, 12, 54) + got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want = [][]byte{ + cc(time1, 1), + 
cc(time5, 1), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList booleans, after clear\n%v", d) + } + }) + t.Run("float64", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(_ []byte) int { + return 8 + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 0, 0, 0, 0, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 0, 0, 0, 0, 0, 0, 1, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 0, 0, 0, 0, 1, 0, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0, 0, 0, 0, 1, 0, 0, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 0, 1, 0, 0, 0, 0)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, 0, 0, 0, 0, 0, 0, 1, 0), + cc(time2, 0, 0, 0, 0, 1, 0, 0, 0), + cc(time3, 0, 0, 0, 0, 0, 1, 0, 0), + cc(time4, 0, 0, 0, 1, 0, 0, 0, 0), + cc(time5, 0, 0, 0, 0, 0, 0, 0, 1), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList float64s \n%v", d) + } + + d.ClearOrderedListState(linkID, wKey, uKey, 11, 12) + d.ClearOrderedListState(linkID, wKey, uKey, 33, 34) + d.ClearOrderedListState(linkID, wKey, uKey, 55, 56) + + got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want = [][]byte{ + cc(time2, 0, 0, 0, 0, 1, 0, 0, 0), + cc(time4, 0, 0, 0, 1, 0, 0, 0, 0), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList float64s, after clear \n%v", d) + } + }) + + t.Run("varint", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(b []byte) int { + _, n := protowire.ConsumeVarint(b) + return int(n) + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, protowire.AppendVarint(nil, 56)...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, protowire.AppendVarint(nil, 20067)...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, protowire.AppendVarint(nil, 7777777)...)) + 
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, protowire.AppendVarint(nil, 424242)...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, protowire.AppendVarint(nil, 0)...)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, protowire.AppendVarint(nil, 424242)...), + cc(time2, protowire.AppendVarint(nil, 56)...), + cc(time3, protowire.AppendVarint(nil, 7777777)...), + cc(time4, protowire.AppendVarint(nil, 20067)...), + cc(time5, protowire.AppendVarint(nil, 0)...), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList int32 \n%v", d) + } + }) + t.Run("lp", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(b []byte) int { + l, n := protowire.ConsumeVarint(b) + return int(l) + n + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, []byte("\u0003one")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, []byte("\u0003two")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, []byte("\u0005three")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, []byte("\u0004four")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, []byte("\u0003one")...), + cc(time2, []byte("\u0003two")...), + cc(time3, []byte("\u0005three")...), + cc(time4, []byte("\u0004four")...), + cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList int32 \n%v", d) + } + }) + t.Run("lp_onecall", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(b []byte) int { + l, n := protowire.ConsumeVarint(b) + return int(l) + n + }, + }, + } + d.AppendOrderedListState(linkID, wKey, uKey, bytes.Join([][]byte{ + time5, []byte("\u0019FourHundredAndEleventyTwo"), + time3, 
[]byte("\u0005three"), + time2, []byte("\u0003two"), + time1, []byte("\u0003one"), + time4, []byte("\u0004four"), + }, nil)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, []byte("\u0003one")...), + cc(time2, []byte("\u0003two")...), + cc(time3, []byte("\u0005three")...), + cc(time4, []byte("\u0004four")...), + cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList int32 \n%v", d) + } + }) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index 1739efdb742a..3cfcf9ef8c0e 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -184,10 +184,10 @@ type Config struct { // // Watermarks are advanced based on consumed input, except if the stage produces residuals. type ElementManager struct { - config Config + config Config + nextBundID func() string // Generates unique bundleIDs. Set in the Bundles method. - impulses set[string] // List of impulse stages. - stages map[string]*stageState // The state for each stage. + stages map[string]*stageState // The state for each stage. consumers map[string][]string // Map from pcollectionID to stageIDs that consumes them as primary input. sideConsumers map[string][]LinkID // Map from pcollectionID to the stage+transform+input that consumes them as side input. @@ -197,6 +197,7 @@ type ElementManager struct { refreshCond sync.Cond // refreshCond protects the following fields with it's lock, and unblocks bundle scheduling. inprogressBundles set[string] // Active bundleIDs changedStages set[string] // Stages that have changed and need their watermark refreshed. + injectedBundles []RunBundle // Represents ready to execute bundles prepared outside of the main loop, such as for onWindowExpiration, or for Triggers. 
livePending atomic.Int64 // An accessible live pending count. DEBUG USE ONLY pendingElements sync.WaitGroup // pendingElements counts all unprocessed elements in a job. Jobs with no pending elements terminate successfully. @@ -248,7 +249,6 @@ func (em *ElementManager) AddStage(ID string, inputIDs, outputIDs []string, side // so we must do it here. if len(inputIDs) == 0 { refreshes := singleSet(ss.ID) - em.addToTestStreamImpulseSet(refreshes) em.markStagesAsChanged(refreshes) } @@ -267,8 +267,20 @@ func (em *ElementManager) StageAggregates(ID string) { // StageStateful marks the given stage as stateful, which means elements are // processed by key. -func (em *ElementManager) StageStateful(ID string) { - em.stages[ID].stateful = true +func (em *ElementManager) StageStateful(ID string, stateTypeLen map[LinkID]func([]byte) int) { + ss := em.stages[ID] + ss.stateful = true + ss.stateTypeLen = stateTypeLen +} + +// StageOnWindowExpiration marks the given stage as stateful, which means elements are +// processed by key. +func (em *ElementManager) StageOnWindowExpiration(stageID string, timer StaticTimerID) { + ss := em.stages[stageID] + ss.onWindowExpiration = timer + ss.keysToExpireByWindow = map[typex.Window]set[string]{} + ss.inProgressExpiredWindows = map[typex.Window]int{} + ss.expiryWindowsByBundles = map[string]typex.Window{} } // StageProcessingTimeTimers indicates which timers are processingTime domain timers. @@ -305,22 +317,9 @@ func (em *ElementManager) Impulse(stageID string) { em.addPending(count) } refreshes := stage.updateWatermarks(em) - - em.addToTestStreamImpulseSet(refreshes) em.markStagesAsChanged(refreshes) } -// addToTestStreamImpulseSet adds to the set of stages to refresh on pipeline start. -// We keep this separate since impulses are synthetic. In a test stream driven pipeline -// these will need to be stimulated separately, to ensure the test stream has progressed. 
-func (em *ElementManager) addToTestStreamImpulseSet(refreshes set[string]) { - if em.impulses == nil { - em.impulses = refreshes - } else { - em.impulses.merge(refreshes) - } -} - type RunBundle struct { StageID string BundleID string @@ -338,6 +337,8 @@ func (rb RunBundle) LogValue() slog.Value { // The returned channel is closed when the context is canceled, or there are no pending elements // remaining. func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context.CancelCauseFunc, nextBundID func() string) <-chan RunBundle { + // Make it easier for injected bundles to get unique IDs. + em.nextBundID = nextBundID runStageCh := make(chan RunBundle) ctx, cancelFn := context.WithCancelCause(ctx) go func() { @@ -357,12 +358,6 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context. }() defer close(runStageCh) - // If we have a test stream, clear out existing changed stages, - // so the test stream can insert any elements it needs. - if em.testStreamHandler != nil { - em.changedStages = singleSet(em.testStreamHandler.ID) - } - for { em.refreshCond.L.Lock() // Check if processing time has advanced before the wait loop. @@ -370,8 +365,9 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context. changedByProcessingTime := em.processTimeEvents.AdvanceTo(emNow) em.changedStages.merge(changedByProcessingTime) - // If there are no changed stages or ready processing time events available, we wait until there are. - for len(em.changedStages)+len(changedByProcessingTime) == 0 { + // If there are no changed stages, ready processing time events, + // or injected bundles available, we wait until there are. + for len(em.changedStages)+len(changedByProcessingTime)+len(em.injectedBundles) == 0 { // Check to see if we must exit select { case <-ctx.Done(): @@ -386,6 +382,19 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context. 
changedByProcessingTime = em.processTimeEvents.AdvanceTo(emNow) em.changedStages.merge(changedByProcessingTime) } + // Run any injected bundles first. + for len(em.injectedBundles) > 0 { + rb := em.injectedBundles[0] + em.injectedBundles = em.injectedBundles[1:] + em.refreshCond.L.Unlock() + + select { + case <-ctx.Done(): + return + case runStageCh <- rb: + } + em.refreshCond.L.Lock() + } // We know there is some work we can do that may advance the watermarks, // refresh them, and see which stages have advanced. @@ -628,6 +637,12 @@ type Block struct { Transform, Family string } +// StaticTimerID represents the static user identifiers for a timer, +// in particular, the ID of the Transform, and the family for the timer. +type StaticTimerID struct { + TransformID, TimerFamily string +} + // StateForBundle retreives relevant state for the given bundle, WRT the data in the bundle. // // TODO(lostluck): Consider unifiying with InputForBundle, to reduce lock contention. @@ -635,7 +650,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData { ss := em.stages[rb.StageID] ss.mu.Lock() defer ss.mu.Unlock() - var ret TentativeData + ret := TentativeData{ + stateTypeLen: ss.stateTypeLen, + } keys := ss.inprogressKeysByBundle[rb.BundleID] // TODO(lostluck): Also track windows per bundle, to reduce copying. if len(ss.state) > 0 { @@ -847,6 +864,19 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol } delete(stage.inprogressHoldsByBundle, rb.BundleID) + // Clean up OnWindowExpiration bundle accounting, so window state + // may be garbage collected. 
+ if stage.expiryWindowsByBundles != nil { + win, ok := stage.expiryWindowsByBundles[rb.BundleID] + if ok { + stage.inProgressExpiredWindows[win] -= 1 + if stage.inProgressExpiredWindows[win] == 0 { + delete(stage.inProgressExpiredWindows, win) + } + delete(stage.expiryWindowsByBundles, rb.BundleID) + } + } + // If there are estimated output watermarks, set the estimated // output watermark for the stage. if len(residuals.MinOutputWatermarks) > 0 { @@ -1068,6 +1098,12 @@ type stageState struct { strat winStrat // Windowing Strategy for aggregation fireings. processingTimeTimersFamilies map[string]bool // Indicates which timer families use the processing time domain. + // onWindowExpiration management + onWindowExpiration StaticTimerID // The static ID of the OnWindowExpiration callback. + keysToExpireByWindow map[typex.Window]set[string] // Tracks all keys ever used with a window, so they may be expired. + inProgressExpiredWindows map[typex.Window]int // Tracks the number of bundles currently expiring these windows, so we don't prematurely garbage collect them. + expiryWindowsByBundles map[string]typex.Window // Tracks which bundle is handling which window, so the above map can be cleared. + mu sync.Mutex upstreamWatermarks sync.Map // watermark set from inputPCollection's parent. input mtime.Time // input watermark for the parallel input. @@ -1083,6 +1119,7 @@ type stageState struct { inprogressKeys set[string] // all keys that are assigned to bundles. inprogressKeysByBundle map[string]set[string] // bundle to key assignments. state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey + stateTypeLen map[LinkID]func([]byte) int // map from state to a function that will produce the total length of a single value in bytes. // Accounting for handling watermark holds for timers. 
// We track the count of timers with the same hold, and clear it from @@ -1158,6 +1195,14 @@ func (ss *stageState) AddPending(newPending []element) int { timers: map[timerKey]timerTimes{}, } ss.pendingByKeys[string(e.keyBytes)] = dnt + if ss.keysToExpireByWindow != nil { + w, ok := ss.keysToExpireByWindow[e.window] + if !ok { + w = make(set[string]) + ss.keysToExpireByWindow[e.window] = w + } + w.insert(string(e.keyBytes)) + } } heap.Push(&dnt.elements, e) @@ -1555,48 +1600,143 @@ func (ss *stageState) updateWatermarks(em *ElementManager) set[string] { if minWatermarkHold < newOut { newOut = minWatermarkHold } - refreshes := set[string]{} + // If the newOut is smaller, then don't change downstream watermarks. + if newOut <= ss.output { + return nil + } + // If bigger, advance the output watermark - if newOut > ss.output { - ss.output = newOut - for _, outputCol := range ss.outputIDs { - consumers := em.consumers[outputCol] - - for _, sID := range consumers { - em.stages[sID].updateUpstreamWatermark(outputCol, ss.output) - refreshes.insert(sID) - } - // Inform side input consumers, but don't update the upstream watermark. - for _, sID := range em.sideConsumers[outputCol] { - refreshes.insert(sID.Global) - } - } - // Garbage collect state, timers and side inputs, for all windows - // that are before the new output watermark. - // They'll never be read in again. - for _, wins := range ss.sideInputs { - for win := range wins { - // TODO(#https://github.com/apache/beam/issues/31438): - // Adjust with AllowedLateness - // Clear out anything we've already used. - if win.MaxTimestamp() < newOut { - delete(wins, win) + preventDownstreamUpdate := ss.createOnWindowExpirationBundles(newOut, em) + + // Garbage collect state, timers and side inputs, for all windows + // that are before the new output watermark, if they aren't in progress + // of being expired. + // They'll never be read in again. 
+ for _, wins := range ss.sideInputs { + for win := range wins { + // TODO(#https://github.com/apache/beam/issues/31438): + // Adjust with AllowedLateness + // Clear out anything we've already used. + if win.MaxTimestamp() < newOut { + // If the expiry is in progress, skip this window. + if ss.inProgressExpiredWindows[win] > 0 { + continue } + delete(wins, win) } } - for _, wins := range ss.state { - for win := range wins { - // TODO(#https://github.com/apache/beam/issues/31438): - // Adjust with AllowedLateness - if win.MaxTimestamp() < newOut { - delete(wins, win) + } + for _, wins := range ss.state { + for win := range wins { + // TODO(#https://github.com/apache/beam/issues/31438): + // Adjust with AllowedLateness + if win.MaxTimestamp() < newOut { + // If the expiry is in progress, skip collecting this window. + if ss.inProgressExpiredWindows[win] > 0 { + continue } + delete(wins, win) } } } + // If there are windows to expire, we don't update the output watermark yet. + if preventDownstreamUpdate { + return nil + } + + // Update this stage's output watermark, and then propagate that to downstream stages + refreshes := set[string]{} + ss.output = newOut + for _, outputCol := range ss.outputIDs { + consumers := em.consumers[outputCol] + + for _, sID := range consumers { + em.stages[sID].updateUpstreamWatermark(outputCol, ss.output) + refreshes.insert(sID) + } + // Inform side input consumers, but don't update the upstream watermark. + for _, sID := range em.sideConsumers[outputCol] { + refreshes.insert(sID.Global) + } + } return refreshes } +// createOnWindowExpirationBundles injects bundles when windows +// expire for all keys that were used in that window. Returns true if any +// bundles are created, which means that the window must not yet be garbage +// collected. +// +// Must be called within the stageState.mu's and the ElementManager.refreshCond +// critical sections. 
+func (ss *stageState) createOnWindowExpirationBundles(newOut mtime.Time, em *ElementManager) bool { + var preventDownstreamUpdate bool + for win, keys := range ss.keysToExpireByWindow { + // Check if the window has expired. + // TODO(#https://github.com/apache/beam/issues/31438): + // Adjust with AllowedLateness + if win.MaxTimestamp() >= newOut { + continue + } + // We can't advance the output watermark if there's garbage to collect. + preventDownstreamUpdate = true + // Hold off on garbage collecting data for these windows while these + // are in progress. + ss.inProgressExpiredWindows[win] += 1 + + // Produce bundle(s) for these keys and window, and inject them. + wm := win.MaxTimestamp() + rb := RunBundle{StageID: ss.ID, BundleID: "owe-" + em.nextBundID(), Watermark: wm} + + // Now we need to actually build the bundle. + var toProcess []element + busyKeys := set[string]{} + usedKeys := set[string]{} + for k := range keys { + if ss.inprogressKeys.present(k) { + busyKeys.insert(k) + continue + } + usedKeys.insert(k) + toProcess = append(toProcess, element{ + window: win, + timestamp: wm, + pane: typex.NoFiringPane(), + holdTimestamp: wm, + transform: ss.onWindowExpiration.TransformID, + family: ss.onWindowExpiration.TimerFamily, + sequence: 1, + keyBytes: []byte(k), + elmBytes: nil, + }) + } + em.addPending(len(toProcess)) + ss.watermarkHolds.Add(wm, 1) + ss.makeInProgressBundle( + func() string { return rb.BundleID }, + toProcess, + wm, + usedKeys, + map[mtime.Time]int{wm: 1}, + ) + ss.expiryWindowsByBundles[rb.BundleID] = win + + slog.Debug("OnWindowExpiration-Bundle Created", slog.Any("bundle", rb), slog.Any("usedKeys", usedKeys), slog.Any("window", win), slog.Any("toProcess", toProcess), slog.Any("busyKeys", busyKeys)) + // We're already in the refreshCond critical section. + // Insert that this is in progress here to avoid a race condition. 
+ em.inprogressBundles.insert(rb.BundleID) + em.injectedBundles = append(em.injectedBundles, rb) + + // Remove the key accounting, or continue tracking which keys still need clearing. + if len(busyKeys) == 0 { + delete(ss.keysToExpireByWindow, win) + } else { + ss.keysToExpireByWindow[win] = busyKeys + } + } + return preventDownstreamUpdate +} + // bundleReady returns the maximum allowed watermark for this stage, and whether // it's permitted to execute by side inputs. func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.Time, bool, bool) { diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go index d5904b13fb88..0d7da5ea163f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "sync/atomic" "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -524,3 +525,162 @@ func TestElementManager(t *testing.T) { } }) } + +func TestElementManager_OnWindowExpiration(t *testing.T) { + t.Run("createOnWindowExpirationBundles", func(t *testing.T) { + // Unlike the other tests above, we synthesize the input configuration, + em := NewElementManager(Config{}) + var instID uint64 + em.nextBundID = func() string { + return fmt.Sprintf("inst%03d", atomic.AddUint64(&instID, 1)) + } + em.AddStage("impulse", nil, []string{"input"}, nil) + em.AddStage("dofn", []string{"input"}, nil, nil) + onWE := StaticTimerID{ + TransformID: "dofn1", + TimerFamily: "onWinExp", + } + em.StageOnWindowExpiration("dofn", onWE) + em.Impulse("impulse") + + stage := em.stages["dofn"] + stage.pendingByKeys = map[string]*dataAndTimers{} + stage.inprogressKeys = set[string]{} + + validateInProgressExpiredWindows := func(t *testing.T, win typex.Window, want int) { + t.Helper() + if got := 
stage.inProgressExpiredWindows[win]; got != want { + t.Errorf("stage.inProgressExpiredWindows[%v] = %v, want %v", win, got, want) + } + } + validateSideBundles := func(t *testing.T, keys set[string]) { + t.Helper() + if len(em.injectedBundles) == 0 { + t.Errorf("no injectedBundles exist when checking keys: %v", keys) + } + // Check that all keys are marked as in progress + for k := range keys { + if !stage.inprogressKeys.present(k) { + t.Errorf("key %q not marked as in progress", k) + } + } + + bundleID := "" + sideBundles: + for _, rb := range em.injectedBundles { + // find that a side channel bundle exists with these keys. + bkeys := stage.inprogressKeysByBundle[rb.BundleID] + if len(bkeys) != len(keys) { + continue sideBundles + } + for k := range keys { + if !bkeys.present(k) { + continue sideBundles + } + } + bundleID = rb.BundleID + break + } + if bundleID == "" { + t.Errorf("no bundle found with all the given keys: %v: bundles: %v keysByBundle: %v", keys, em.injectedBundles, stage.inprogressKeysByBundle) + } + } + + newOut := mtime.EndOfGlobalWindowTime + // No windows exist, so no side channel bundles should be set. + if got, want := stage.createOnWindowExpirationBundles(newOut, em), false; got != want { + t.Errorf("createOnWindowExpirationBundles(%v) = %v, want %v", newOut, got, want) + } + // Validate that no side channel bundles were created. + if got, want := len(stage.inProgressExpiredWindows), 0; got != want { + t.Errorf("len(stage.inProgressExpiredWindows) = %v, want %v", got, want) + } + if got, want := len(em.injectedBundles), 0; got != want { + t.Errorf("len(em.injectedBundles) = %v, want %v", got, want) + } + + // Configure a few conditions to validate in the call. + // Each window is in it's own bundle, all are in the same bundle. 
+ // Bundle 1 + expiredWindow1 := window.IntervalWindow{Start: 0, End: newOut - 1} + + akey := "\u0004key1" + keys1 := singleSet(akey) + stage.keysToExpireByWindow[expiredWindow1] = keys1 + // Bundle 2 + expiredWindow2 := window.IntervalWindow{Start: 1, End: newOut - 1} + keys2 := singleSet("\u0004key2") + keys2.insert("\u0004key3") + keys2.insert("\u0004key4") + stage.keysToExpireByWindow[expiredWindow2] = keys2 + + // We should never see this key and window combination, as the window is + // not yet expired. + liveWindow := window.IntervalWindow{Start: 2, End: newOut + 1} + stage.keysToExpireByWindow[liveWindow] = singleSet("\u0010keyNotSeen") + + if got, want := stage.createOnWindowExpirationBundles(newOut, em), true; got != want { + t.Errorf("createOnWindowExpirationBundles(%v) = %v, want %v", newOut, got, want) + } + + // We should only see 2 injectedBundles at this point. + if got, want := len(em.injectedBundles), 2; got != want { + t.Errorf("len(em.injectedBundles) = %v, want %v", got, want) + } + + validateInProgressExpiredWindows(t, expiredWindow1, 1) + validateInProgressExpiredWindows(t, expiredWindow2, 1) + validateSideBundles(t, keys1) + validateSideBundles(t, keys2) + + // Bundle 3 + expiredWindow3 := window.IntervalWindow{Start: 3, End: newOut - 1} + keys3 := singleSet(akey) // We shouldn't see this key, since it's in progress. + keys3.insert("\u0004key5") // We should see this key since it isn't. + stage.keysToExpireByWindow[expiredWindow3] = keys3 + + if got, want := stage.createOnWindowExpirationBundles(newOut, em), true; got != want { + t.Errorf("createOnWindowExpirationBundles(%v) = %v, want %v", newOut, got, want) + } + + // We should see 3 injectedBundles at this point. 
+ if got, want := len(em.injectedBundles), 3; got != want { + t.Errorf("len(em.injectedBundles) = %v, want %v", got, want) + } + + validateInProgressExpiredWindows(t, expiredWindow1, 1) + validateInProgressExpiredWindows(t, expiredWindow2, 1) + validateInProgressExpiredWindows(t, expiredWindow3, 1) + validateSideBundles(t, keys1) + validateSideBundles(t, keys2) + validateSideBundles(t, singleSet("\u0004key5")) + + // remove key1 from "inprogress keys", and the associated bundle. + stage.inprogressKeys.remove(akey) + delete(stage.inProgressExpiredWindows, expiredWindow1) + for bundID, bkeys := range stage.inprogressKeysByBundle { + if bkeys.present(akey) { + t.Logf("bundID: %v, bkeys: %v, keyByBundle: %v", bundID, bkeys, stage.inprogressKeysByBundle) + delete(stage.inprogressKeysByBundle, bundID) + win := stage.expiryWindowsByBundles[bundID] + delete(stage.expiryWindowsByBundles, bundID) + if win != expiredWindow1 { + t.Fatalf("Unexpected window: got %v, want %v", win, expiredWindow1) + } + break + } + } + + // Now we should get another bundle for expiredWindow3, and have none for expiredWindow1 + if got, want := stage.createOnWindowExpirationBundles(newOut, em), true; got != want { + t.Errorf("createOnWindowExpirationBundles(%v) = %v, want %v", newOut, got, want) + } + + validateInProgressExpiredWindows(t, expiredWindow1, 0) + validateInProgressExpiredWindows(t, expiredWindow2, 1) + validateInProgressExpiredWindows(t, expiredWindow3, 2) + validateSideBundles(t, keys1) // Should still have this key present, but with a different bundle. + validateSideBundles(t, keys2) + validateSideBundles(t, singleSet("\u0004key5")) // still exist.. 
+ }) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go b/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go index ed3df75fd8ef..533cd5a0fc40 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go @@ -142,10 +142,8 @@ func (ts *testStreamHandler) UpdateHold(em *ElementManager, newHold mtime.Time) ts.currentHold = newHold ss.watermarkHolds.Add(ts.currentHold, 1) - // kick the TestStream and Impulse stages too. + // kick the TestStream stage to ensure downstream watermark propagation. kick := singleSet(ts.ID) - kick.merge(em.impulses) - // This executes under the refreshCond lock, so we can't call em.addRefreshes. em.changedStages.merge(kick) em.refreshCond.Broadcast() diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go index 2f960a04f0cb..be4809f5e2f6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/environments.go +++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go @@ -147,7 +147,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock ccr, err := cli.ContainerCreate(ctx, &container.Config{ Image: dp.GetContainerImage(), Cmd: []string{ - fmt.Sprintf("--id=%v-%v", wk.JobKey, wk.Env), + fmt.Sprintf("--id=%v", wk.ID), fmt.Sprintf("--control_endpoint=%v", wk.Endpoint()), fmt.Sprintf("--artifact_endpoint=%v", artifactEndpoint), fmt.Sprintf("--provision_endpoint=%v", wk.Endpoint()), diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index 614edee47721..d41c3cd9c75c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -53,13 +53,24 @@ func RunPipeline(j *jobservices.Job) { envs := j.Pipeline.GetComponents().GetEnvironments() wks := map[string]*worker.W{} for envID := range envs { - wk, err := 
makeWorker(envID, j) - if err != nil { - j.Failed(err) + wk := j.MakeWorker(envID) + wks[envID] = wk + if err := runEnvironment(j.RootCtx, j, envID, wk); err != nil { + j.Failed(fmt.Errorf("failed to start environment %v for job %v: %w", envID, j, err)) return } - wks[envID] = wk + // Check for connection succeeding after we've created the environment successfully. + timeout := 1 * time.Minute + time.AfterFunc(timeout, func() { + if wk.Connected() || wk.Stopped() { + return + } + err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout) + j.Failed(err) + j.CancelFn(err) + }) } + // When this function exits, we cancel the context to clear // any related job resources. defer func() { @@ -86,33 +97,6 @@ func RunPipeline(j *jobservices.Job) { j.Done() } -// makeWorker creates a worker for that environment. -func makeWorker(env string, j *jobservices.Job) (*worker.W, error) { - wk := worker.New(j.String()+"_"+env, env) - - wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env] - wk.PipelineOptions = j.PipelineOptions() - wk.JobKey = j.JobKey() - wk.ArtifactEndpoint = j.ArtifactEndpoint() - - go wk.Serve() - - if err := runEnvironment(j.RootCtx, j, env, wk); err != nil { - return nil, fmt.Errorf("failed to start environment %v for job %v: %w", env, j, err) - } - // Check for connection succeeding after we've created the environment successfully. 
- timeout := 1 * time.Minute - time.AfterFunc(timeout, func() { - if wk.Connected() || wk.Stopped() { - return - } - err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout) - j.Failed(err) - j.CancelFn(err) - }) - return wk, nil -} - type transformExecuter interface { ExecuteUrns() []string ExecuteTransform(stageID, tid string, t *pipepb.PTransform, comps *pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B @@ -285,7 +269,6 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic elms = append(elms, engine.TestStreamElement{Encoded: mayLP(e.GetEncodedElement()), EventTime: mtime.Time(e.GetTimestamp())}) } tsb.AddElementEvent(ev.ElementEvent.GetTag(), elms) - ev.ElementEvent.GetTag() case *pipepb.TestStreamPayload_Event_WatermarkEvent: tsb.AddWatermarkEvent(ev.WatermarkEvent.GetTag(), mtime.Time(ev.WatermarkEvent.GetNewWatermark())) case *pipepb.TestStreamPayload_Event_ProcessingTimeEvent: @@ -316,7 +299,11 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic sort.Strings(outputs) em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs) if stage.stateful { - em.StageStateful(stage.ID) + em.StageStateful(stage.ID, stage.stateTypeLen) + } + if stage.onWindowExpiration.TimerFamily != "" { + slog.Debug("OnWindowExpiration", slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration)) + em.StageOnWindowExpiration(stage.ID, stage.onWindowExpiration) } if len(stage.processingTimeTimers) > 0 { em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index deef259a99d1..4be64e5a9c80 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -38,6 +38,7 @@ import ( jobpb 
"github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "google.golang.org/protobuf/types/known/structpb" ) @@ -45,6 +46,7 @@ var supportedRequirements = map[string]struct{}{ urns.RequirementSplittableDoFn: {}, urns.RequirementStatefulProcessing: {}, urns.RequirementBundleFinalization: {}, + urns.RequirementOnWindowExpiration: {}, } // TODO, move back to main package, and key off of executor handlers? @@ -92,6 +94,7 @@ type Job struct { Logger *slog.Logger metrics metricsStore + mw *worker.MultiplexW } func (j *Job) ArtifactEndpoint() string { @@ -197,3 +200,14 @@ func (j *Job) Failed(err error) { j.sendState(jobpb.JobState_FAILED) j.CancelFn(fmt.Errorf("jobFailed %v: %w", j, err)) } + +// MakeWorker instantiates a worker.W populating environment and pipeline data from the Job. +func (j *Job) MakeWorker(env string) *worker.W { + wk := j.mw.MakeWorker(j.String()+"_"+env, env) + wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env] + wk.PipelineOptions = j.PipelineOptions() + wk.JobKey = j.JobKey() + wk.ArtifactEndpoint = j.ArtifactEndpoint() + + return wk +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index a2840760bf7a..b9a28e4bc652 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -94,6 +94,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * }, Logger: s.logger, // TODO substitute with a configured logger. artifactEndpoint: s.Endpoint(), + mw: s.mw, } // Stop the idle timer when a new job appears. 
if idleTimer := s.idleTimer.Load(); idleTimer != nil { @@ -174,7 +175,8 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * // Validate all the state features for _, spec := range pardo.GetStateSpecs() { isStateful = true - check("StateSpec.Protocol.Urn", spec.GetProtocol().GetUrn(), urns.UserStateBag, urns.UserStateMultiMap) + check("StateSpec.Protocol.Urn", spec.GetProtocol().GetUrn(), + urns.UserStateBag, urns.UserStateMultiMap, urns.UserStateOrderedList) } // Validate all the timer features for _, spec := range pardo.GetTimerFamilySpecs() { @@ -182,8 +184,6 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * check("TimerFamilySpecs.TimeDomain.Urn", spec.GetTimeDomain(), pipepb.TimeDomain_EVENT_TIME, pipepb.TimeDomain_PROCESSING_TIME) } - check("OnWindowExpirationTimerFamily", pardo.GetOnWindowExpirationTimerFamilySpec(), "") // Unsupported for now. - // Check for a stateful SDF and direct user to https://github.com/apache/beam/issues/32139 if pardo.GetRestrictionCoderId() != "" && isStateful { check("Splittable+Stateful DoFn", "See https://github.com/apache/beam/issues/32139 for information.", "") diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index bdfe2aff2dd4..fb55fc54bf93 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -28,6 +28,7 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "google.golang.org/grpc" ) @@ -60,6 +61,8 @@ type Server struct { // Artifact hack artifacts map[string][]byte + + mw *worker.MultiplexW } // NewServer acquires the indicated port. 
@@ -82,6 +85,9 @@ func NewServer(port int, execute func(*Job)) *Server { jobpb.RegisterJobServiceServer(s.server, s) jobpb.RegisterArtifactStagingServiceServer(s.server, s) jobpb.RegisterArtifactRetrievalServiceServer(s.server, s) + + s.mw = worker.NewMultiplexW(lis, s.server, s.logger) + return s } diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index dceaa9ab8fcb..2048d32e4ad4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -445,10 +445,15 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil { return fmt.Errorf("unable to decode ParDoPayload for %v", link.Transform) } - stg.finalize = pardo.RequestsFinalization + if pardo.GetRequestsFinalization() { + stg.finalize = true + } if len(pardo.GetTimerFamilySpecs())+len(pardo.GetStateSpecs())+len(pardo.GetOnWindowExpirationTimerFamilySpec()) > 0 { stg.stateful = true } + if pardo.GetOnWindowExpirationTimerFamilySpec() != "" { + stg.onWindowExpiration = engine.StaticTimerID{TransformID: link.Transform, TimerFamily: pardo.GetOnWindowExpirationTimerFamilySpec()} + } sis = pardo.GetSideInputs() } if _, ok := sis[link.Local]; ok { diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index 9f00c22789b6..e1e942a06f0c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -35,6 +35,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" ) @@ -57,20 +58,26 @@ type link struct { // account, but all serialization boundaries remain since the 
pcollections // would continue to get serialized. type stage struct { - ID string - transforms []string - primaryInput string // PCollection used as the parallel input. - outputs []link // PCollections that must escape this stage. - sideInputs []engine.LinkID // Non-parallel input PCollections and their consumers - internalCols []string // PCollections that escape. Used for precise coder sending. - envID string - finalize bool - stateful bool + ID string + transforms []string + primaryInput string // PCollection used as the parallel input. + outputs []link // PCollections that must escape this stage. + sideInputs []engine.LinkID // Non-parallel input PCollections and their consumers + internalCols []string // PCollections that escape. Used for precise coder sending. + envID string + finalize bool + stateful bool + onWindowExpiration engine.StaticTimerID + // hasTimers indicates the transform+timerfamily pairs that need to be waited on for // the stage to be considered complete. - hasTimers []struct{ Transform, TimerFamily string } + hasTimers []engine.StaticTimerID processingTimeTimers map[string]bool + // stateTypeLen maps state values to encoded lengths for the type. + // Only used for OrderedListState which must manipulate individual state datavalues. + stateTypeLen map[engine.LinkID]func([]byte) int + exe transformExecuter inputTransformID string inputInfo engine.PColInfo @@ -436,6 +443,38 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng rewriteCoder(&s.SetSpec.ElementCoderId) case *pipepb.StateSpec_OrderedListSpec: rewriteCoder(&s.OrderedListSpec.ElementCoderId) + // Add the length determination helper for OrderedList state values. 
+ if stg.stateTypeLen == nil { + stg.stateTypeLen = map[engine.LinkID]func([]byte) int{} + } + linkID := engine.LinkID{ + Transform: tid, + Local: stateID, + } + var fn func([]byte) int + switch v := coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() { + case urns.CoderBool: + fn = func(_ []byte) int { + return 1 + } + case urns.CoderDouble: + fn = func(_ []byte) int { + return 8 + } + case urns.CoderVarInt: + fn = func(b []byte) int { + _, n := protowire.ConsumeVarint(b) + return int(n) + } + case urns.CoderLengthPrefix, urns.CoderBytes, urns.CoderStringUTF8: + fn = func(b []byte) int { + l, n := protowire.ConsumeVarint(b) + return int(l) + n + } + default: + rewriteErr = fmt.Errorf("unknown coder used for ordered list state after re-write id: %v coder: %v, for state %v for transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID, tid, stg.ID) + } + stg.stateTypeLen[linkID] = fn case *pipepb.StateSpec_CombiningSpec: rewriteCoder(&s.CombiningSpec.AccumulatorCoderId) case *pipepb.StateSpec_MapSpec: @@ -452,7 +491,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng } } for timerID, v := range pardo.GetTimerFamilySpecs() { - stg.hasTimers = append(stg.hasTimers, struct{ Transform, TimerFamily string }{Transform: tid, TimerFamily: timerID}) + stg.hasTimers = append(stg.hasTimers, engine.StaticTimerID{TransformID: tid, TimerFamily: timerID}) if v.TimeDomain == pipepb.TimeDomain_PROCESSING_TIME { if stg.processingTimeTimers == nil { stg.processingTimeTimers = map[string]bool{} diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go index 5312fd799c89..12e62ef84a81 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go +++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go @@ -95,8 +95,9 @@ var ( SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP) // UserState kinds - UserStateBag = 
usUrn(pipepb.StandardUserStateTypes_BAG) - UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP) + UserStateBag = usUrn(pipepb.StandardUserStateTypes_BAG) + UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP) + UserStateOrderedList = usUrn(pipepb.StandardUserStateTypes_ORDERED_LIST) // WindowsFns WindowFnGlobal = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 83ad1bda9841..14cd84aef821 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -42,7 +42,7 @@ type B struct { InputTransformID string Input []*engine.Block // Data and Timers for this bundle. EstimatedInputElements int - HasTimers []struct{ Transform, TimerFamily string } // Timer streams to terminate. + HasTimers []engine.StaticTimerID // Timer streams to terminate. // IterableSideInputData is a map from transformID + inputID, to window, to data. 
IterableSideInputData map[SideInputKey]map[typex.Window][][]byte @@ -190,7 +190,7 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { for _, tid := range b.HasTimers { timers = append(timers, &fnpb.Elements_Timers{ InstructionId: b.InstID, - TransformId: tid.Transform, + TransformId: tid.TransformID, TimerFamilyId: tid.TimerFamily, IsLast: true, }) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go index 161fb199ce96..08d30f67e445 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go @@ -25,7 +25,7 @@ import ( ) func TestBundle_ProcessOn(t *testing.T) { - wk := New("test", "testEnv") + wk := newWorker() b := &B{ InstID: "testInst", PBDID: "testPBDID", diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index c2c988aa097f..b4133b0332a6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -23,11 +23,9 @@ import ( "fmt" "io" "log/slog" - "math" "net" "sync" "sync/atomic" - "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -38,6 +36,7 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" + "github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -55,16 +54,14 @@ type W struct { fnpb.UnimplementedBeamFnLoggingServer fnpb.UnimplementedProvisionServiceServer + parentPool *MultiplexW + ID, Env string JobKey, ArtifactEndpoint string EnvPb *pipepb.Environment PipelineOptions *structpb.Struct - 
// Server management - lis net.Listener - server *grpc.Server - // These are the ID sources inst uint64 connected, stopped atomic.Bool @@ -82,45 +79,8 @@ type controlResponder interface { Respond(*fnpb.InstructionResponse) } -// New starts the worker server components of FnAPI Execution. -func New(id, env string) *W { - lis, err := net.Listen("tcp", ":0") - if err != nil { - panic(fmt.Sprintf("failed to listen: %v", err)) - } - opts := []grpc.ServerOption{ - grpc.MaxRecvMsgSize(math.MaxInt32), - } - wk := &W{ - ID: id, - Env: env, - lis: lis, - server: grpc.NewServer(opts...), - - InstReqs: make(chan *fnpb.InstructionRequest, 10), - DataReqs: make(chan *fnpb.Elements, 10), - StoppedChan: make(chan struct{}), - - activeInstructions: make(map[string]controlResponder), - Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor), - } - slog.Debug("Serving Worker components", slog.String("endpoint", wk.Endpoint())) - fnpb.RegisterBeamFnControlServer(wk.server, wk) - fnpb.RegisterBeamFnDataServer(wk.server, wk) - fnpb.RegisterBeamFnLoggingServer(wk.server, wk) - fnpb.RegisterBeamFnStateServer(wk.server, wk) - fnpb.RegisterProvisionServiceServer(wk.server, wk) - return wk -} - func (wk *W) Endpoint() string { - _, port, _ := net.SplitHostPort(wk.lis.Addr().String()) - return fmt.Sprintf("localhost:%v", port) -} - -// Serve serves on the started listener. Blocks. -func (wk *W) Serve() { - wk.server.Serve(wk.lis) + return wk.parentPool.endpoint } func (wk *W) String() string { @@ -154,16 +114,7 @@ func (wk *W) shutdown() { // Stop the GRPC server. func (wk *W) Stop() { wk.shutdown() - - // Give the SDK side 5 seconds to gracefully stop, before - // hard stopping all RPCs. 
- tim := time.AfterFunc(5*time.Second, func() { - wk.server.Stop() - }) - wk.server.GracefulStop() - tim.Stop() - - wk.lis.Close() + wk.parentPool.delete(wk) slog.Debug("stopped", "worker", wk) } @@ -554,6 +505,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case *fnpb.StateKey_MultimapKeysUserState_: mmkey := key.GetMultimapKeysUserState() data = b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey()) + case *fnpb.StateKey_OrderedListUserState_: + olkey := key.GetOrderedListUserState() + data = b.OutputData.GetOrderedListState( + engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()}, + olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd()) default: panic(fmt.Sprintf("unsupported StateKey Get type: %T: %v", key.GetType(), prototext.Format(key))) } @@ -578,6 +534,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case *fnpb.StateKey_MultimapUserState_: mmkey := key.GetMultimapUserState() b.OutputData.AppendMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData()) + case *fnpb.StateKey_OrderedListUserState_: + olkey := key.GetOrderedListUserState() + b.OutputData.AppendOrderedListState( + engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()}, + olkey.GetWindow(), olkey.GetKey(), req.GetAppend().GetData()) default: panic(fmt.Sprintf("unsupported StateKey Append type: %T: %v", key.GetType(), prototext.Format(key))) } @@ -601,6 +562,10 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case *fnpb.StateKey_MultimapKeysUserState_: mmkey := key.GetMultimapUserState() b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey()) + case 
*fnpb.StateKey_OrderedListUserState_: + olkey := key.GetOrderedListUserState() + b.OutputData.ClearOrderedListState(engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()}, + olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd()) default: panic(fmt.Sprintf("unsupported StateKey Clear type: %T: %v", key.GetType(), prototext.Format(key))) } @@ -696,3 +661,131 @@ func (wk *W) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb. }, }).GetMonitoringInfos() } + +// MultiplexW forwards FnAPI gRPC requests to W it manages in an in-memory pool. +type MultiplexW struct { + fnpb.UnimplementedBeamFnControlServer + fnpb.UnimplementedBeamFnDataServer + fnpb.UnimplementedBeamFnStateServer + fnpb.UnimplementedBeamFnLoggingServer + fnpb.UnimplementedProvisionServiceServer + + mu sync.Mutex + endpoint string + logger *slog.Logger + pool map[string]*W +} + +// NewMultiplexW instantiates a new FnAPI server for multiplexing FnAPI requests to a W. +func NewMultiplexW(lis net.Listener, g *grpc.Server, logger *slog.Logger) *MultiplexW { + _, p, _ := net.SplitHostPort(lis.Addr().String()) + mw := &MultiplexW{ + endpoint: "localhost:" + p, + logger: logger, + pool: make(map[string]*W), + } + + fnpb.RegisterBeamFnControlServer(g, mw) + fnpb.RegisterBeamFnDataServer(g, mw) + fnpb.RegisterBeamFnLoggingServer(g, mw) + fnpb.RegisterBeamFnStateServer(g, mw) + fnpb.RegisterProvisionServiceServer(g, mw) + + return mw +} + +// MakeWorker creates and registers a W, assigning id and env to W.ID and W.Env, respectively, associating W.ID +// to *W for later lookup. MultiplexW expects FnAPI gRPC requests to contain a matching 'worker_id' in its context +// metadata. A gRPC client should use the grpcx.WriteWorkerID helper method prior to sending the request. 
+func (mw *MultiplexW) MakeWorker(id, env string) *W { + mw.mu.Lock() + defer mw.mu.Unlock() + w := &W{ + ID: id, + Env: env, + + InstReqs: make(chan *fnpb.InstructionRequest, 10), + DataReqs: make(chan *fnpb.Elements, 10), + StoppedChan: make(chan struct{}), + + activeInstructions: make(map[string]controlResponder), + Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor), + parentPool: mw, + } + mw.pool[id] = w + return w +} + +func (mw *MultiplexW) GetProvisionInfo(ctx context.Context, req *fnpb.GetProvisionInfoRequest) (*fnpb.GetProvisionInfoResponse, error) { + return handleUnary(mw, ctx, req, (*W).GetProvisionInfo) +} + +func (mw *MultiplexW) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { + return handleStream(mw, stream.Context(), stream, (*W).Logging) +} + +func (mw *MultiplexW) GetProcessBundleDescriptor(ctx context.Context, req *fnpb.GetProcessBundleDescriptorRequest) (*fnpb.ProcessBundleDescriptor, error) { + return handleUnary(mw, ctx, req, (*W).GetProcessBundleDescriptor) +} + +func (mw *MultiplexW) Control(ctrl fnpb.BeamFnControl_ControlServer) error { + return handleStream(mw, ctrl.Context(), ctrl, (*W).Control) +} + +func (mw *MultiplexW) Data(data fnpb.BeamFnData_DataServer) error { + return handleStream(mw, data.Context(), data, (*W).Data) +} + +func (mw *MultiplexW) State(state fnpb.BeamFnState_StateServer) error { + return handleStream(mw, state.Context(), state, (*W).State) +} + +func (mw *MultiplexW) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb.MonitoringInfosMetadataResponse { + mw.mu.Lock() + defer mw.mu.Unlock() + w, err := mw.workerFromMetadataCtx(ctx) + if err != nil { + mw.logger.Error(err.Error()) + return nil + } + return w.MonitoringMetadata(ctx, unknownIDs) +} + +func (mw *MultiplexW) workerFromMetadataCtx(ctx context.Context) (*W, error) { + mw.mu.Lock() + defer mw.mu.Unlock() + id, err := grpcx.ReadWorkerID(ctx) + if err != nil { + return nil, err + } + if id == "" { + return nil, 
fmt.Errorf("worker_id read from context metadata is an empty string") + } + w, ok := mw.pool[id] + if !ok { + return nil, fmt.Errorf("worker_id: '%s' read from context metadata but not registered in worker pool", id) + } + return w, nil +} + +func (mw *MultiplexW) delete(w *W) { + mw.mu.Lock() + defer mw.mu.Unlock() + delete(mw.pool, w.ID) +} + +func handleUnary[Request any, Response any, Method func(*W, context.Context, *Request) (*Response, error)](mw *MultiplexW, ctx context.Context, req *Request, m Method) (*Response, error) { + w, err := mw.workerFromMetadataCtx(ctx) + if err != nil { + return nil, err + } + return m(w, ctx, req) +} + +func handleStream[Stream any, Method func(*W, Stream) error](mw *MultiplexW, ctx context.Context, stream Stream, m Method) error { + w, err := mw.workerFromMetadataCtx(ctx) + if err != nil { + return err + } + return m(w, stream) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go index 469e0e2f3d83..a0cf577fbdba 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -18,34 +18,88 @@ package worker import ( "bytes" "context" + "log/slog" "net" "sort" "sync" "testing" "time" - "github.com/google/go-cmp/cmp" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" + "github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/test/bufconn" ) -func TestWorker_New(t 
*testing.T) { - w := New("test", "testEnv") +func TestMultiplexW_MakeWorker(t *testing.T) { + w := newWorker() + if w.parentPool == nil { + t.Errorf("MakeWorker instantiated W with a nil reference to MultiplexW") + } if got, want := w.ID, "test"; got != want { - t.Errorf("New(%q) = %v, want %v", want, got, want) + t.Errorf("MakeWorker(%q) = %v, want %v", want, got, want) + } + got, ok := w.parentPool.pool[w.ID] + if !ok || got == nil { + t.Errorf("MakeWorker(%q) not registered in worker pool %v", w.ID, w.parentPool.pool) + } +} + +func TestMultiplexW_workerFromMetadataCtx(t *testing.T) { + for _, tt := range []struct { + name string + ctx context.Context + want *W + wantErr string + }{ + { + name: "empty ctx metadata", + ctx: context.Background(), + wantErr: "failed to read metadata from context", + }, + { + name: "worker_id empty", + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "")), + wantErr: "worker_id read from context metadata is an empty string", + }, + { + name: "mismatched worker_id", + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "doesn't exist")), + wantErr: "worker_id: 'doesn't exist' read from context metadata but not registered in worker pool", + }, + { + name: "matched worker_id", + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("worker_id", "test")), + want: &W{ID: "test"}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + w := newWorker() + got, err := w.parentPool.workerFromMetadataCtx(tt.ctx) + if err != nil && err.Error() != tt.wantErr { + t.Errorf("workerFromMetadataCtx() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr != "" { + return + } + if got.ID != tt.want.ID { + t.Errorf("workerFromMetadataCtx() id = %v, want %v", got.ID, tt.want.ID) + } + }) } } func TestWorker_NextInst(t *testing.T) { - w := New("test", "testEnv") + w := newWorker() instIDs := map[string]struct{}{} for i := 0; i < 100; i++ { @@ -57,7 +111,7 @@ func 
TestWorker_NextInst(t *testing.T) { } func TestWorker_GetProcessBundleDescriptor(t *testing.T) { - w := New("test", "testEnv") + w := newWorker() id := "available" w.Descriptors[id] = &fnpb.ProcessBundleDescriptor{ @@ -87,19 +141,21 @@ func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) { ctx, cancelFn := context.WithCancel(context.Background()) t.Cleanup(cancelFn) - w := New("test", "testEnv") + g := grpc.NewServer() lis := bufconn.Listen(2048) - w.lis = lis - t.Cleanup(func() { w.Stop() }) - go w.Serve() - - clientConn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { - return lis.DialContext(ctx) - }), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + mw := NewMultiplexW(lis, g, slog.Default()) + t.Cleanup(func() { g.Stop() }) + go g.Serve(lis) + w := mw.MakeWorker("test", "testEnv") + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("worker_id", w.ID)) + ctx = grpcx.WriteWorkerID(ctx, w.ID) + conn, err := grpc.DialContext(ctx, w.Endpoint(), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return lis.Dial() + }), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { t.Fatal("couldn't create bufconn grpc connection:", err) } - return ctx, w, clientConn + return ctx, w, conn } type closeSend func() @@ -465,3 +521,10 @@ func TestWorker_State_MultimapSideInput(t *testing.T) { }) } } + +func newWorker() *W { + mw := &MultiplexW{ + pool: map[string]*W{}, + } + return mw.MakeWorker("test", "testEnv") +} diff --git a/sdks/java/container/license_scripts/dep_urls_java.yaml b/sdks/java/container/license_scripts/dep_urls_java.yaml index 781a0decda78..cdb625bea447 100644 --- a/sdks/java/container/license_scripts/dep_urls_java.yaml +++ b/sdks/java/container/license_scripts/dep_urls_java.yaml @@ -46,7 +46,7 @@ jaxen: '1.1.6': type: "3-Clause BSD" libraries-bom: - '26.49.0': + '26.50.0': license: 
"https://raw.githubusercontent.com/GoogleCloudPlatform/cloud-opensource-java/master/LICENSE" type: "Apache License 2.0" paranamer: diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java index 78ea34503e54..2981046a0a41 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java @@ -110,6 +110,16 @@ enum LogLevel { void setLogMdc(boolean value); + /** This option controls whether logging will be redirected through the FnApi. */ + @Description( + "Controls whether logging will be redirected through the FnApi. In normal usage, setting " + + "this to a non-default value will cause log messages to be dropped.") + @Default.Boolean(true) + @Hidden + boolean getEnableLogViaFnApi(); + + void setEnableLogViaFnApi(boolean enableLogViaFnApi); + /** * Size (in MB) of each grouping table used to pre-combine elements. Larger values may reduce the * amount of data shuffled. If unset, defaults to 100 MB. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java index 7dccc0ebb903..c8e5bfecec20 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java @@ -56,7 +56,7 @@ public static PTransform, PCollection> fromR } /** - * Convert a {@link PCollection}{@literal } into a {@link PCollection}{@literal }. + * Convert a {@link PCollection}{@literal } into a {@link PCollection}{@literal }. * *

The output schema will be inferred using the schema registry. A schema must be registered * for this type, or the conversion will fail. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index fb2321328b32..742547b9b6c3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -1593,7 +1593,7 @@ public void populateDisplayData(DisplayData.Builder builder) { @RunWith(JUnit4.class) public static class BundleFinalizationTests extends SharedTestBase implements Serializable { private abstract static class BundleFinalizingDoFn extends DoFn, String> { - private static final long MAX_ATTEMPTS = 3000; + private static final long MAX_ATTEMPTS = 100; // We use the UUID to uniquely identify this DoFn in case this test is run with // other tests in the same JVM. private static final Map WAS_FINALIZED = new HashMap(); @@ -1637,9 +1637,15 @@ public Void apply(Iterable input) { public void testBundleFinalization() { TestStream.Builder> stream = TestStream.create(KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of())); - for (long i = 0; i < BundleFinalizingDoFn.MAX_ATTEMPTS; ++i) { + long attemptCap = BundleFinalizingDoFn.MAX_ATTEMPTS - 1; + for (long i = 0; i < attemptCap; ++i) { stream = stream.addElements(KV.of("key" + (i % 10), i)); } + // Advance the time, and add the final element. This allows Finalization + // check mechanism to work without being sensitive to how bundles are + // produced by a runner. 
+ stream = stream.advanceWatermarkTo(new Instant(10)); + stream = stream.addElements(KV.of("key" + (attemptCap % 10), attemptCap)); PCollection output = pipeline .apply(stream.advanceWatermarkToInfinity()) @@ -1677,6 +1683,8 @@ public void testBundleFinalizationWithState() { for (long i = 0; i < BundleFinalizingDoFn.MAX_ATTEMPTS; ++i) { stream = stream.addElements(KV.of("key" + (i % 10), i)); } + // Stateful execution is already per-key, so it is unnecessary to add a + // "final" element to attempt additional bundles to validate finalization. PCollection output = pipeline .apply(stream.advanceWatermarkToInfinity()) @@ -1715,9 +1723,15 @@ public void processElement( public void testBundleFinalizationWithSideInputs() { TestStream.Builder> stream = TestStream.create(KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of())); - for (long i = 0; i < BundleFinalizingDoFn.MAX_ATTEMPTS; ++i) { + long attemptCap = BundleFinalizingDoFn.MAX_ATTEMPTS - 1; + for (long i = 0; i < attemptCap; ++i) { stream = stream.addElements(KV.of("key" + (i % 10), i)); } + // Advance the time, and add the final element. This allows Finalization + // check mechanism to work without being sensitive to how bundles are + // produced by a runner. 
+ stream = stream.advanceWatermarkTo(GlobalWindow.INSTANCE.maxTimestamp()); + stream = stream.addElements(KV.of("key" + (attemptCap % 10), attemptCap)); PCollectionView sideInput = pipeline.apply(Create.of("sideInput value")).apply(View.asSingleton()); PCollection output = diff --git a/sdks/java/extensions/sql/expansion-service/build.gradle b/sdks/java/extensions/sql/expansion-service/build.gradle index b8d78e4e1bb9..024041e40b36 100644 --- a/sdks/java/extensions/sql/expansion-service/build.gradle +++ b/sdks/java/extensions/sql/expansion-service/build.gradle @@ -48,5 +48,8 @@ task runExpansionService (type: JavaExec) { } shadowJar { + manifest { + attributes(["Multi-Release": true]) + } outputs.upToDateWhen { false } } \ No newline at end of file diff --git a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/logging/BeamFnLoggingClientBenchmark.java b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/logging/BeamFnLoggingClientBenchmark.java index b9e4b20db00b..745e1f078646 100644 --- a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/logging/BeamFnLoggingClientBenchmark.java +++ b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/logging/BeamFnLoggingClientBenchmark.java @@ -24,6 +24,8 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.fn.harness.logging.BeamFnLoggingClient; import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC; +import org.apache.beam.fn.harness.logging.LoggingClient; +import org.apache.beam.fn.harness.logging.LoggingClientFactory; import org.apache.beam.fn.harness.logging.QuotaEvent; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnLoggingGrpc; @@ -80,7 +82,7 @@ public void onCompleted() { /** Setup a simple logging service and configure the {@link BeamFnLoggingClient}. 
*/ @State(Scope.Benchmark) public static class ManageLoggingClientAndService { - public final BeamFnLoggingClient loggingClient; + public final LoggingClient loggingClient; public final CallCountLoggingService loggingService; public final Server server; @@ -98,7 +100,7 @@ public ManageLoggingClientAndService() { .build(); server.start(); loggingClient = - BeamFnLoggingClient.createAndStart( + LoggingClientFactory.createAndStart( PipelineOptionsFactory.create(), apiServiceDescriptor, managedChannelFactory::forDescriptor); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 9df9f12bc52b..3c8784d0ee42 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -39,7 +39,8 @@ import org.apache.beam.fn.harness.control.ProcessBundleHandler; import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient; import org.apache.beam.fn.harness.debug.DataSampler; -import org.apache.beam.fn.harness.logging.BeamFnLoggingClient; +import org.apache.beam.fn.harness.logging.LoggingClient; +import org.apache.beam.fn.harness.logging.LoggingClientFactory; import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.fn.harness.status.BeamFnStatusClient; import org.apache.beam.fn.harness.stream.HarnessStreamObserverFactories; @@ -62,6 +63,7 @@ import org.apache.beam.sdk.options.ExecutorOptions; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.SdkHarnessOptions; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.sdk.util.construction.PipelineOptionsTranslation; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat; @@ -283,8 +285,8 @@ public static void main( // The logging client variable is 
not used per se, but during its lifetime (until close()) it // intercepts logging and sends it to the logging service. - try (BeamFnLoggingClient logging = - BeamFnLoggingClient.createAndStart( + try (LoggingClient logging = + LoggingClientFactory.createAndStart( options, loggingApiServiceDescriptor, channelFactory::forDescriptor)) { LOG.info("Fn Harness started"); // Register standard file systems. @@ -410,7 +412,11 @@ private BeamFnApi.ProcessBundleDescriptor loadDescriptor(String id) { outboundObserverFactory, executorService, handlers); - CompletableFuture.anyOf(control.terminationFuture(), logging.terminationFuture()).get(); + if (options.as(SdkHarnessOptions.class).getEnableLogViaFnApi()) { + CompletableFuture.anyOf(control.terminationFuture(), logging.terminationFuture()).get(); + } else { + control.terminationFuture().get(); + } if (beamFnStatusClient != null) { beamFnStatusClient.close(); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java index 7812d8c0bc30..112104e4d251 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java @@ -68,7 +68,7 @@ /** * Configures {@link java.util.logging} to send all {@link LogRecord}s via the Beam Fn Logging API. 
*/ -public class BeamFnLoggingClient implements AutoCloseable { +public class BeamFnLoggingClient implements LoggingClient { private static final String ROOT_LOGGER_NAME = ""; private static final ImmutableMap LOG_LEVEL_MAP = ImmutableMap.builder() @@ -119,7 +119,7 @@ public class BeamFnLoggingClient implements AutoCloseable { */ private @Nullable Thread logEntryHandlerThread = null; - public static BeamFnLoggingClient createAndStart( + static BeamFnLoggingClient createAndStart( PipelineOptions options, Endpoints.ApiServiceDescriptor apiServiceDescriptor, Function channelFactory) { @@ -383,6 +383,7 @@ void flushFinalLogs(@UnderInitialization BeamFnLoggingClient this) { } } + @Override public CompletableFuture terminationFuture() { checkNotNull(bufferedLogConsumer, "BeamFnLoggingClient not fully started"); return bufferedLogConsumer; diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/package-info.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/LoggingClient.java similarity index 80% rename from sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/package-info.java rename to sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/LoggingClient.java index 6a36686cd8ab..3c1972ec643d 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/package-info.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/LoggingClient.java @@ -15,6 +15,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.beam.fn.harness.logging; -/** Transforms for reading and writing from Amazon Kinesis. 
*/ -package org.apache.beam.sdk.io.kinesis; +import java.util.concurrent.CompletableFuture; + +public interface LoggingClient extends AutoCloseable { + + CompletableFuture terminationFuture(); +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/LoggingClientFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/LoggingClientFactory.java new file mode 100644 index 000000000000..f61b8d3c4ba3 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/LoggingClientFactory.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.logging; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.SdkHarnessOptions; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; + +/** + * A factory for {@link LoggingClient}s. Provides {@link BeamFnLoggingClient} if the logging service + * is enabled, otherwise provides a no-op client. 
+ */ +public class LoggingClientFactory { + + private LoggingClientFactory() {} + + /** + * A factory for {@link LoggingClient}s. Provides {@link BeamFnLoggingClient} if the logging + * service is enabled, otherwise provides a no-op client. + */ + public static LoggingClient createAndStart( + PipelineOptions options, + Endpoints.ApiServiceDescriptor apiServiceDescriptor, + Function channelFactory) { + if (options.as(SdkHarnessOptions.class).getEnableLogViaFnApi()) { + return BeamFnLoggingClient.createAndStart(options, apiServiceDescriptor, channelFactory); + } else { + return new NoOpLoggingClient(); + } + } + + static final class NoOpLoggingClient implements LoggingClient { + @Override + public CompletableFuture terminationFuture() { + return CompletableFuture.completedFuture(new Object()); + } + + @Override + public void close() throws Exception {} + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java index cc9204a79062..ebfb2b3c41ad 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java @@ -30,6 +30,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.function.Supplier; import org.apache.beam.fn.harness.Cache; import org.apache.beam.fn.harness.Caches; @@ -105,6 +106,12 @@ private static class WrappedObservingIterator extends ElementByteSizeObservab private boolean observerNeedsAdvance = false; private boolean exceptionLogged = false; + private final Random randomGenerator = new Random(); + + // Lowest sampling probability: 0.001%. 
+ private static final int SAMPLING_TOKEN_UPPER_BOUND = 1000000; + private static final int SAMPLING_CUTOFF = 10; + private int samplingToken = 0; static WrappedObservingIterator create( Iterator iterator, org.apache.beam.sdk.coders.Coder elementCoder) { @@ -125,6 +132,18 @@ private WrappedObservingIterator( this.elementCoder = elementCoder; } + private boolean sampleElement() { + // Sampling probability decreases as the element count is increasing. + // We unconditionally sample the first samplingCutoff elements. For the + // next samplingCutoff elements, the sampling probability drops from 100% + // to 50%. The probability of sampling the Nth element is: + // min(1, samplingCutoff / N), with an additional lower bound of + // samplingCutoff / samplingTokenUpperBound. This algorithm may be refined + // later. + samplingToken = Math.min(samplingToken + 1, SAMPLING_TOKEN_UPPER_BOUND); + return randomGenerator.nextInt(samplingToken) < SAMPLING_CUTOFF; + } + @Override public boolean hasNext() { if (observerNeedsAdvance) { @@ -138,15 +157,19 @@ public boolean hasNext() { public T next() { T value = wrappedIterator.next(); try { - elementCoder.registerByteSizeObserver(value, observerProxy); - if (observerProxy.getIsLazy()) { - // The observer will only be notified of bytes as the result - // is used. We defer advancing the observer until hasNext in an - // attempt to capture those bytes. - observerNeedsAdvance = true; - } else { - observerNeedsAdvance = false; - observerProxy.advance(); + if (sampleElement() || elementCoder.isRegisterByteSizeObserverCheap(value)) { + elementCoder.registerByteSizeObserver(value, observerProxy); + observerProxy.setScalingFactor( + Math.max(samplingToken, SAMPLING_CUTOFF) / (double) SAMPLING_CUTOFF); + if (observerProxy.getIsLazy()) { + // The observer will only be notified of bytes as the result + // is used. We defer advancing the observer until hasNext in an + // attempt to capture those bytes. 
+ observerNeedsAdvance = true; + } else { + observerNeedsAdvance = false; + observerProxy.advance(); + } } } catch (Exception e) { if (!exceptionLogged) { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java index 245c87f3e194..96e68586b3a5 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java @@ -47,8 +47,9 @@ import org.apache.beam.fn.harness.control.ExecutionStateSampler; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker; import org.apache.beam.fn.harness.debug.DataSampler; -import org.apache.beam.fn.harness.logging.BeamFnLoggingClient; import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC; +import org.apache.beam.fn.harness.logging.LoggingClient; +import org.apache.beam.fn.harness.logging.LoggingClientFactory; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.fnexecution.v1.BeamFnLoggingGrpc; @@ -647,8 +648,8 @@ public StreamObserver logging( // Start the test within the logging context. This reroutes logging through to the boiler-plate // that was set up // earlier. 
- try (BeamFnLoggingClient ignored = - BeamFnLoggingClient.createAndStart( + try (LoggingClient ignored = + LoggingClientFactory.createAndStart( PipelineOptionsFactory.create(), apiServiceDescriptor, (Endpoints.ApiServiceDescriptor descriptor) -> channel)) { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java index ffc1ba62cb56..14775ed0b6fb 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java @@ -269,7 +269,8 @@ public void testByteObservingStateBackedIterable() throws Exception { .sum(); observer.advance(); // 5 comes from size and hasNext (see IterableLikeCoder) - assertEquals(iterateBytes + 5, observer.total); + // observer due to sampling should observe fewer bytes + assertTrue(iterateBytes + 5 >= observer.total); } } diff --git a/sdks/java/io/amazon-web-services/build.gradle b/sdks/java/io/amazon-web-services/build.gradle deleted file mode 100644 index b9ed51fbbf77..000000000000 --- a/sdks/java/io/amazon-web-services/build.gradle +++ /dev/null @@ -1,74 +0,0 @@ -import groovy.json.JsonOutput - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -plugins { - id 'org.apache.beam.module' -} - -applyJavaNature( - automaticModuleName: 'org.apache.beam.sdk.io.aws', -) - -provideIntegrationTestingDependencies() -enableJavaPerformanceTesting() - -description = "Apache Beam :: SDKs :: Java :: IO :: Amazon Web Services" -ext.summary = "IO library to read and write Amazon Web Services services from Beam." - -dependencies { - implementation library.java.vendored_guava_32_1_2_jre - implementation project(path: ":sdks:java:core", configuration: "shadow") - implementation library.java.aws_java_sdk_cloudwatch - implementation library.java.aws_java_sdk_core - implementation library.java.aws_java_sdk_dynamodb - implementation library.java.aws_java_sdk_s3 - implementation library.java.aws_java_sdk_sns - implementation library.java.aws_java_sdk_sqs - implementation library.java.aws_java_sdk_sts - implementation library.java.jackson_core - implementation library.java.jackson_annotations - implementation library.java.jackson_databind - implementation library.java.slf4j_api - implementation library.java.joda_time - implementation library.java.http_core - runtimeOnly library.java.commons_codec - runtimeOnly "org.apache.httpcomponents:httpclient:4.5.12" - testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") - testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration") - testImplementation project(path: ":sdks:java:io:common") - testImplementation "io.findify:s3mock_2.12:0.2.6" - testImplementation library.java.commons_lang3 - 
testImplementation library.java.hamcrest - testImplementation library.java.mockito_core - testImplementation library.java.junit - testImplementation library.java.testcontainers_localstack - testImplementation "org.assertj:assertj-core:3.11.1" - testImplementation 'org.elasticmq:elasticmq-rest-sqs_2.12:0.15.6' - testRuntimeOnly library.java.slf4j_jdk14 - testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") -} - -test { - systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ - '--awsRegion=us-west-2', - '--awsCredentialsProvider={"@type" : "AWSStaticCredentialsProvider", "awsAccessKeyId" : "key_id_value","awsSecretKey" : "secret_value"}' - ]) - maxParallelForks 4 -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/AwsCoders.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/AwsCoders.java deleted file mode 100644 index 501bfc015860..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/AwsCoders.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.coders; - -import com.amazonaws.ResponseMetadata; -import com.amazonaws.http.HttpResponse; -import com.amazonaws.http.SdkHttpMetadata; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.Map; -import java.util.Optional; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.CustomCoder; -import org.apache.beam.sdk.coders.MapCoder; -import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; - -/** {@link Coder}s for common AWS SDK objects. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public final class AwsCoders { - - private AwsCoders() {} - - /** - * Returns a new coder for ResponseMetadata. - * - * @return the ResponseMetadata coder - */ - public static Coder responseMetadata() { - return ResponseMetadataCoder.of(); - } - - /** - * Returns a new coder for SdkHttpMetadata. - * - * @return the SdkHttpMetadata coder - */ - public static Coder sdkHttpMetadata() { - return new SdkHttpMetadataCoder(true); - } - - /** - * Returns a new coder for SdkHttpMetadata that does not serialize the response headers. 
- * - * @return the SdkHttpMetadata coder - */ - public static Coder sdkHttpMetadataWithoutHeaders() { - return new SdkHttpMetadataCoder(false); - } - - private static class ResponseMetadataCoder extends AtomicCoder { - - private static final Coder> METADATA_ENCODER = - NullableCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); - private static final ResponseMetadataCoder INSTANCE = new ResponseMetadataCoder(); - - private ResponseMetadataCoder() {} - - public static ResponseMetadataCoder of() { - return INSTANCE; - } - - @Override - public void encode(ResponseMetadata value, OutputStream outStream) - throws CoderException, IOException { - METADATA_ENCODER.encode( - ImmutableMap.of(ResponseMetadata.AWS_REQUEST_ID, value.getRequestId()), outStream); - } - - @Override - public ResponseMetadata decode(InputStream inStream) throws CoderException, IOException { - return new ResponseMetadata(METADATA_ENCODER.decode(inStream)); - } - } - - private static class SdkHttpMetadataCoder extends CustomCoder { - - private static final Coder STATUS_CODE_CODER = VarIntCoder.of(); - private static final Coder> HEADERS_ENCODER = - NullableCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); - - private final boolean includeHeaders; - - protected SdkHttpMetadataCoder(boolean includeHeaders) { - this.includeHeaders = includeHeaders; - } - - @Override - public void encode(SdkHttpMetadata value, OutputStream outStream) - throws CoderException, IOException { - STATUS_CODE_CODER.encode(value.getHttpStatusCode(), outStream); - if (includeHeaders) { - HEADERS_ENCODER.encode(value.getHttpHeaders(), outStream); - } - } - - @Override - public SdkHttpMetadata decode(InputStream inStream) throws CoderException, IOException { - final int httpStatusCode = STATUS_CODE_CODER.decode(inStream); - HttpResponse httpResponse = new HttpResponse(null, null); - httpResponse.setStatusCode(httpStatusCode); - if (includeHeaders) { - 
Optional.ofNullable(HEADERS_ENCODER.decode(inStream)) - .ifPresent( - headers -> - headers.keySet().forEach(k -> httpResponse.addHeader(k, headers.get(k)))); - } - return SdkHttpMetadata.from(httpResponse); - } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - STATUS_CODE_CODER.verifyDeterministic(); - if (includeHeaders) { - HEADERS_ENCODER.verifyDeterministic(); - } - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoder.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoder.java deleted file mode 100644 index 4bdf8b51d3b2..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoder.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import com.amazonaws.services.dynamodbv2.model.AttributeValue; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.stream.Collectors; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.BooleanCoder; -import org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.ListCoder; -import org.apache.beam.sdk.coders.MapCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; - -/** A {@link Coder} that serializes and deserializes the {@link AttributeValue} objects. */ -public class AttributeValueCoder extends AtomicCoder { - - /** Data type of each value type in AttributeValue object. */ - private enum AttributeValueType { - s, // for String - n, // for Number - b, // for Byte - sS, // for List of String - nS, // for List of Number - bS, // for List of Byte - m, // for Map of String and AttributeValue - l, // for list of AttributeValue - bOOL, // for Boolean - nULLValue, // for null - } - - private static final AttributeValueCoder INSTANCE = new AttributeValueCoder(); - - private static final ListCoder LIST_STRING_CODER = ListCoder.of(StringUtf8Coder.of()); - private static final ListCoder LIST_BYTE_CODER = ListCoder.of(ByteArrayCoder.of()); - - private static final ListCoder LIST_ATTRIBUTE_CODER = - ListCoder.of(AttributeValueCoder.of()); - private static final MapCoder MAP_ATTRIBUTE_CODER = - MapCoder.of(StringUtf8Coder.of(), AttributeValueCoder.of()); - - private AttributeValueCoder() {} - - public static AttributeValueCoder of() { - return INSTANCE; - } - - @Override - public void encode(AttributeValue value, OutputStream outStream) throws IOException { - - if (value.getS() != null) { - StringUtf8Coder.of().encode(AttributeValueType.s.toString(), outStream); - 
StringUtf8Coder.of().encode(value.getS(), outStream); - } else if (value.getN() != null) { - StringUtf8Coder.of().encode(AttributeValueType.n.toString(), outStream); - StringUtf8Coder.of().encode(value.getN(), outStream); - } else if (value.getBOOL() != null) { - StringUtf8Coder.of().encode(AttributeValueType.bOOL.toString(), outStream); - BooleanCoder.of().encode(value.getBOOL(), outStream); - } else if (value.getB() != null) { - StringUtf8Coder.of().encode(AttributeValueType.b.toString(), outStream); - ByteArrayCoder.of().encode(convertToByteArray(value.getB()), outStream); - } else if (value.getSS() != null) { - StringUtf8Coder.of().encode(AttributeValueType.sS.toString(), outStream); - LIST_STRING_CODER.encode(value.getSS(), outStream); - } else if (value.getNS() != null) { - StringUtf8Coder.of().encode(AttributeValueType.nS.toString(), outStream); - LIST_STRING_CODER.encode(value.getNS(), outStream); - } else if (value.getBS() != null) { - StringUtf8Coder.of().encode(AttributeValueType.bS.toString(), outStream); - LIST_BYTE_CODER.encode(convertToListByteArray(value.getBS()), outStream); - } else if (value.getL() != null) { - StringUtf8Coder.of().encode(AttributeValueType.l.toString(), outStream); - LIST_ATTRIBUTE_CODER.encode(value.getL(), outStream); - } else if (value.getM() != null) { - StringUtf8Coder.of().encode(AttributeValueType.m.toString(), outStream); - MAP_ATTRIBUTE_CODER.encode(value.getM(), outStream); - } else if (value.getNULL() != null) { - StringUtf8Coder.of().encode(AttributeValueType.nULLValue.toString(), outStream); - BooleanCoder.of().encode(value.getNULL(), outStream); - } else { - throw new CoderException("Unknown Type"); - } - } - - @Override - public AttributeValue decode(InputStream inStream) throws IOException { - AttributeValue attrValue = new AttributeValue(); - - String type = StringUtf8Coder.of().decode(inStream); - AttributeValueType attrType = AttributeValueType.valueOf(type); - - switch (attrType) { - case s: - 
attrValue.setS(StringUtf8Coder.of().decode(inStream)); - break; - case n: - attrValue.setN(StringUtf8Coder.of().decode(inStream)); - break; - case bOOL: - attrValue.setBOOL(BooleanCoder.of().decode(inStream)); - break; - case b: - attrValue.setB(ByteBuffer.wrap(ByteArrayCoder.of().decode(inStream))); - break; - case sS: - attrValue.setSS(LIST_STRING_CODER.decode(inStream)); - break; - case nS: - attrValue.setNS(LIST_STRING_CODER.decode(inStream)); - break; - case bS: - attrValue.setBS(convertToListByteBuffer(LIST_BYTE_CODER.decode(inStream))); - break; - case l: - attrValue.setL(LIST_ATTRIBUTE_CODER.decode(inStream)); - break; - case m: - attrValue.setM(MAP_ATTRIBUTE_CODER.decode(inStream)); - break; - case nULLValue: - attrValue.setNULL(BooleanCoder.of().decode(inStream)); - break; - default: - throw new CoderException("Unknown Type"); - } - - return attrValue; - } - - private List convertToListByteArray(List listByteBuffer) { - return listByteBuffer.stream().map(this::convertToByteArray).collect(Collectors.toList()); - } - - private byte[] convertToByteArray(ByteBuffer buffer) { - byte[] bytes = new byte[buffer.remaining()]; - buffer.get(bytes); - buffer.position(buffer.position() - bytes.length); - return bytes; - } - - private List convertToListByteBuffer(List listByteArr) { - return listByteArr.stream().map(ByteBuffer::wrap).collect(Collectors.toList()); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java deleted file mode 100644 index f2d13b144e8d..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import java.io.Serializable; - -/** - * Provides instances of AWS clients. - * - *

Please note, that any instance of {@link AwsClientsProvider} must be {@link Serializable} to - * ensure it can be sent to worker machines. - */ -public interface AwsClientsProvider extends Serializable { - - /** @deprecated DynamoDBIO doesn't require a CloudWatch client */ - @Deprecated - @SuppressWarnings("return") - default AmazonCloudWatch getCloudWatchClient() { - return null; - } - - AmazonDynamoDB createDynamoDB(); -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java deleted file mode 100644 index b4ee1be74abe..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** Basic implementation of {@link AwsClientsProvider} used by default in {@link DynamoDBIO}. */ -public class BasicDynamoDBProvider implements AwsClientsProvider { - private final String accessKey; - private final String secretKey; - private final Regions region; - private final @Nullable String serviceEndpoint; - - BasicDynamoDBProvider( - String accessKey, String secretKey, Regions region, @Nullable String serviceEndpoint) { - checkArgument(accessKey != null, "accessKey can not be null"); - checkArgument(secretKey != null, "secretKey can not be null"); - checkArgument(region != null, "region can not be null"); - this.accessKey = accessKey; - this.secretKey = secretKey; - this.region = region; - this.serviceEndpoint = serviceEndpoint; - } - - private AWSCredentialsProvider getCredentialsProvider() { - return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)); - } - - @Override - public AmazonDynamoDB createDynamoDB() { - AmazonDynamoDBClientBuilder clientBuilder = - AmazonDynamoDBClientBuilder.standard().withCredentials(getCredentialsProvider()); - - if (serviceEndpoint == null) { - clientBuilder.withRegion(region); - } else { - clientBuilder.withEndpointConfiguration( - new EndpointConfiguration(serviceEndpoint, region.getName())); - } - - return clientBuilder.build(); - } -} diff --git 
a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIO.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIO.java deleted file mode 100644 index e2c04c58b45d..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIO.java +++ /dev/null @@ -1,630 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import static java.util.stream.Collectors.groupingBy; -import static java.util.stream.Collectors.mapping; -import static java.util.stream.Collectors.toList; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.amazonaws.regions.Regions; -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import com.amazonaws.services.dynamodbv2.model.AmazonDynamoDBException; -import com.amazonaws.services.dynamodbv2.model.AttributeValue; -import com.amazonaws.services.dynamodbv2.model.BatchWriteItemRequest; -import com.amazonaws.services.dynamodbv2.model.BatchWriteItemResult; -import com.amazonaws.services.dynamodbv2.model.ScanRequest; -import com.amazonaws.services.dynamodbv2.model.ScanResult; -import com.amazonaws.services.dynamodbv2.model.WriteRequest; -import com.google.auto.value.AutoValue; -import java.io.IOException; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Predicate; -import java.util.stream.Collectors; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.ListCoder; -import org.apache.beam.sdk.coders.MapCoder; -import org.apache.beam.sdk.coders.SerializableCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.Reshuffle; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; -import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; 
-import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PBegin; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.http.HttpStatus; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * {@link PTransform}s to read/write from/to Amazon - * DynamoDB. - * - *

Writing to DynamoDB

- * - *

Example usage: - * - *

{@code
- * PCollection data = ...;
- * data.apply(
- *           DynamoDBIO.write()
- *               .withWriteRequestMapperFn(
- *                   (SerializableFunction>)
- *                       //Transforming your T data into KV
- *                       t -> KV.of(tableName, writeRequest))
- *               .withRetryConfiguration(
- *                    DynamoDBIO.RetryConfiguration.create(5, Duration.standardMinutes(1)))
- *               .withAwsClientsProvider(new BasicDynamoDbProvider(accessKey, secretKey, region));
- * }
- * - *

As a client, you need to provide at least the following things: - * - *

    - *
  • Retry configuration - *
  • Specify AwsClientsProvider. You can pass on the default one BasicDynamoDbProvider - *
  • Mapper function with a table name to map or transform your object into KV - *
- * - * Note: AWS does not allow writing duplicate keys within a single batch operation. If - * primary keys possibly repeat in your stream (i.e. an upsert stream), you may encounter a - * `ValidationError`. To address this you have to provide the key names corresponding to your - * primary key using {@link Write#withDeduplicateKeys(List)}. Based on these keys only the last - * observed element is kept. Nevertheless, if no deduplication keys are provided, identical elements - * are still deduplicated. - * - *

Reading from DynamoDB

- * - *

Example usage: - * - *

{@code
- * PCollection>> output =
- *     pipeline.apply(
- *             DynamoDBIO.>>read()
- *                 .withAwsClientsProvider(new BasicDynamoDBProvider(accessKey, secretKey, region))
- *                 .withScanRequestFn(
- *                     (SerializableFunction)
- *                         input -> new ScanRequest(tableName).withTotalSegments(1))
- *                 .items());
- * }
- * - *

As a client, you need to provide at least the following things: - * - *

    - *
  • Specify AwsClientsProvider. You can pass on the default one BasicDynamoDBProvider - *
  • ScanRequestFn, which you build a ScanRequest object with at least table name and total - * number of segment. Note This number should base on the number of your workers - *
- * - * @deprecated Module beam-sdks-java-io-amazon-web-services is deprecated and will be - * eventually removed. Please migrate to {@link org.apache.beam.sdk.io.aws2.dynamodb.DynamoDBIO} - * in module beam-sdks-java-io-amazon-web-services2. - */ -@Deprecated -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public final class DynamoDBIO { - public static Read read() { - return new AutoValue_DynamoDBIO_Read.Builder().build(); - } - - public static Write write() { - return new AutoValue_DynamoDBIO_Write.Builder() - .setDeduplicateKeys(new ArrayList<>()) - .build(); - } - - /** Read data from DynamoDB and return ScanResult. */ - @AutoValue - public abstract static class Read extends PTransform> { - - abstract @Nullable AwsClientsProvider getAwsClientsProvider(); - - abstract @Nullable SerializableFunction getScanRequestFn(); - - abstract @Nullable Integer getSegmentId(); - - abstract @Nullable SerializableFunction getScanResultMapperFn(); - - abstract @Nullable Coder getCoder(); - - abstract Builder toBuilder(); - - @AutoValue.Builder - abstract static class Builder { - - abstract Builder setAwsClientsProvider(AwsClientsProvider awsClientsProvider); - - abstract Builder setScanRequestFn(SerializableFunction fn); - - abstract Builder setSegmentId(Integer segmentId); - - abstract Builder setScanResultMapperFn( - SerializableFunction scanResultMapperFn); - - abstract Builder setCoder(Coder coder); - - abstract Read build(); - } - - public Read withAwsClientsProvider(AwsClientsProvider awsClientsProvider) { - return toBuilder().setAwsClientsProvider(awsClientsProvider).build(); - } - - public Read withAwsClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region, String serviceEndpoint) { - return withAwsClientsProvider( - new BasicDynamoDBProvider(awsAccessKey, awsSecretKey, region, serviceEndpoint)); - } - - public Read withAwsClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region) { - 
return withAwsClientsProvider(awsAccessKey, awsSecretKey, region, null); - } - - /** - * Can't pass ScanRequest object directly from client since this object is not full - * serializable. - */ - public Read withScanRequestFn(SerializableFunction fn) { - return toBuilder().setScanRequestFn(fn).build(); - } - - private Read withSegmentId(Integer segmentId) { - checkArgument(segmentId != null, "segmentId can not be null"); - return toBuilder().setSegmentId(segmentId).build(); - } - - public Read withScanResultMapperFn(SerializableFunction scanResultMapperFn) { - checkArgument(scanResultMapperFn != null, "scanResultMapper can not be null"); - return toBuilder().setScanResultMapperFn(scanResultMapperFn).build(); - } - - public Read>> items() { - // safe cast as both mapper and coder are updated accordingly - Read>> self = (Read>>) this; - return self.withScanResultMapperFn(new DynamoDBIO.Read.ItemsMapper()) - .withCoder(ListCoder.of(MapCoder.of(StringUtf8Coder.of(), AttributeValueCoder.of()))); - } - - public Read withCoder(Coder coder) { - checkArgument(coder != null, "coder can not be null"); - return toBuilder().setCoder(coder).build(); - } - - @Override - public PCollection expand(PBegin input) { - LoggerFactory.getLogger(DynamoDBIO.class) - .warn( - "You are using a deprecated IO for DynamoDB. 
Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - checkArgument((getScanRequestFn() != null), "withScanRequestFn() is required"); - checkArgument((getAwsClientsProvider() != null), "withAwsClientsProvider() is required"); - ScanRequest scanRequest = getScanRequestFn().apply(null); - checkArgument( - (scanRequest.getTotalSegments() != null && scanRequest.getTotalSegments() > 0), - "TotalSegments is required with withScanRequestFn() and greater zero"); - - PCollection> splits = - input.apply("Create", Create.of(this)).apply("Split", ParDo.of(new SplitFn<>())); - splits.setCoder(SerializableCoder.of(new TypeDescriptor>() {})); - - PCollection output = - splits - .apply("Reshuffle", Reshuffle.viaRandomKey()) - .apply("Read", ParDo.of(new ReadFn<>())); - output.setCoder(getCoder()); - return output; - } - - /** A {@link DoFn} to split {@link Read} elements by segment id. */ - private static class SplitFn extends DoFn, Read> { - @ProcessElement - public void processElement(@Element Read spec, OutputReceiver> out) { - ScanRequest scanRequest = spec.getScanRequestFn().apply(null); - for (int i = 0; i < scanRequest.getTotalSegments(); i++) { - out.output(spec.withSegmentId(i)); - } - } - } - - /** A {@link DoFn} executing the ScanRequest to read from DynamoDB. 
*/ - private static class ReadFn extends DoFn, T> { - @ProcessElement - public void processElement(@Element Read spec, OutputReceiver out) { - AmazonDynamoDB client = spec.getAwsClientsProvider().createDynamoDB(); - Map lastEvaluatedKey = null; - - do { - ScanRequest scanRequest = spec.getScanRequestFn().apply(null); - scanRequest.setSegment(spec.getSegmentId()); - if (lastEvaluatedKey != null) { - scanRequest.withExclusiveStartKey(lastEvaluatedKey); - } - - ScanResult scanResult = client.scan(scanRequest); - out.output(spec.getScanResultMapperFn().apply(scanResult)); - lastEvaluatedKey = scanResult.getLastEvaluatedKey(); - } while (lastEvaluatedKey != null); // iterate until all records are fetched - } - } - - static final class ItemsMapper - implements SerializableFunction>> { - @Override - public List> apply(@Nullable ScanResult scanResult) { - if (scanResult == null) { - return Collections.emptyList(); - } - return scanResult.getItems(); - } - } - } - - /** - * A POJO encapsulating a configuration for retry behavior when issuing requests to DynamoDB. A - * retry will be attempted until the maxAttempts or maxDuration is exceeded, whichever comes - * first, for any of the following exceptions: - * - *
    - *
  • {@link IOException} - *
- */ - @AutoValue - public abstract static class RetryConfiguration implements Serializable { - private static final Duration DEFAULT_INITIAL_DURATION = Duration.standardSeconds(5); - - @VisibleForTesting - static final RetryPredicate DEFAULT_RETRY_PREDICATE = new DefaultRetryPredicate(); - - abstract int getMaxAttempts(); - - abstract Duration getMaxDuration(); - - abstract Duration getInitialDuration(); - - abstract DynamoDBIO.RetryConfiguration.RetryPredicate getRetryPredicate(); - - abstract DynamoDBIO.RetryConfiguration.Builder builder(); - - public static DynamoDBIO.RetryConfiguration create(int maxAttempts, Duration maxDuration) { - return create(maxAttempts, maxDuration, DEFAULT_INITIAL_DURATION); - } - - static DynamoDBIO.RetryConfiguration create( - int maxAttempts, Duration maxDuration, Duration initialDuration) { - checkArgument(maxAttempts > 0, "maxAttempts should be greater than 0"); - checkArgument( - maxDuration != null && maxDuration.isLongerThan(Duration.ZERO), - "maxDuration should be greater than 0"); - checkArgument( - initialDuration != null && initialDuration.isLongerThan(Duration.ZERO), - "initialDuration should be greater than 0"); - - return new AutoValue_DynamoDBIO_RetryConfiguration.Builder() - .setMaxAttempts(maxAttempts) - .setMaxDuration(maxDuration) - .setInitialDuration(initialDuration) - .setRetryPredicate(DEFAULT_RETRY_PREDICATE) - .build(); - } - - @AutoValue.Builder - abstract static class Builder { - abstract DynamoDBIO.RetryConfiguration.Builder setMaxAttempts(int maxAttempts); - - abstract DynamoDBIO.RetryConfiguration.Builder setMaxDuration(Duration maxDuration); - - abstract DynamoDBIO.RetryConfiguration.Builder setInitialDuration(Duration initialDuration); - - abstract DynamoDBIO.RetryConfiguration.Builder setRetryPredicate( - RetryPredicate retryPredicate); - - abstract DynamoDBIO.RetryConfiguration build(); - } - - /** - * An interface used to control if we retry the BatchWriteItemRequest call when a {@link - * 
Throwable} occurs. If {@link RetryPredicate#test(Object)} returns true, {@link Write} tries - * to resend the requests to the DynamoDB server if the {@link RetryConfiguration} permits it. - */ - @FunctionalInterface - interface RetryPredicate extends Predicate, Serializable {} - - private static class DefaultRetryPredicate implements RetryPredicate { - private static final ImmutableSet ELIGIBLE_CODES = - ImmutableSet.of(HttpStatus.SC_SERVICE_UNAVAILABLE); - - @Override - public boolean test(Throwable throwable) { - return (throwable instanceof IOException - || (throwable instanceof AmazonDynamoDBException) - || (throwable instanceof AmazonDynamoDBException - && ELIGIBLE_CODES.contains(((AmazonDynamoDBException) throwable).getStatusCode()))); - } - } - } - - /** Write a PCollection data into DynamoDB. */ - @AutoValue - public abstract static class Write extends PTransform, PCollection> { - - abstract @Nullable AwsClientsProvider getAwsClientsProvider(); - - abstract @Nullable RetryConfiguration getRetryConfiguration(); - - abstract @Nullable SerializableFunction> getWriteItemMapperFn(); - - abstract List getDeduplicateKeys(); - - abstract Builder builder(); - - @AutoValue.Builder - abstract static class Builder { - - abstract Builder setAwsClientsProvider(AwsClientsProvider awsClientsProvider); - - abstract Builder setRetryConfiguration(RetryConfiguration retryConfiguration); - - abstract Builder setWriteItemMapperFn( - SerializableFunction> writeItemMapperFn); - - abstract Builder setDeduplicateKeys(List deduplicateKeys); - - abstract Write build(); - } - - public Write withAwsClientsProvider(AwsClientsProvider awsClientsProvider) { - return builder().setAwsClientsProvider(awsClientsProvider).build(); - } - - public Write withAwsClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region, String serviceEndpoint) { - return withAwsClientsProvider( - new BasicDynamoDBProvider(awsAccessKey, awsSecretKey, region, serviceEndpoint)); - } - - public Write 
withAwsClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region) { - return withAwsClientsProvider(awsAccessKey, awsSecretKey, region, null); - } - - /** - * Provides configuration to retry a failed request to publish a set of records to DynamoDB. - * Users should consider that retrying might compound the underlying problem which caused the - * initial failure. Users should also be aware that once retrying is exhausted the error is - * surfaced to the runner which may then opt to retry the current partition in entirety - * or abort if the max number of retries of the runner is completed. Retrying uses an - * exponential backoff algorithm, with minimum backoff of 5 seconds and then surfacing the error - * once the maximum number of retries or maximum configuration duration is exceeded. - * - *

Example use: - * - *

{@code
-     * DynamoDBIO.write()
-     *   .withRetryConfiguration(DynamoDBIO.RetryConfiguration.create(5, Duration.standardMinutes(1))
-     *   ...
-     * }
- * - * @param retryConfiguration the rules which govern the retry behavior - * @return the {@link DynamoDBIO.Write} with retrying configured - */ - public Write withRetryConfiguration(RetryConfiguration retryConfiguration) { - checkArgument(retryConfiguration != null, "retryConfiguration is required"); - return builder().setRetryConfiguration(retryConfiguration).build(); - } - - public Write withWriteRequestMapperFn( - SerializableFunction> writeItemMapperFn) { - return builder().setWriteItemMapperFn(writeItemMapperFn).build(); - } - - public Write withDeduplicateKeys(List deduplicateKeys) { - return builder().setDeduplicateKeys(deduplicateKeys).build(); - } - - @Override - public PCollection expand(PCollection input) { - LoggerFactory.getLogger(DynamoDBIO.class) - .warn( - "You are using a deprecated IO for DynamoDB. Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - return input.apply(ParDo.of(new WriteFn<>(this))); - } - - static class WriteFn extends DoFn { - @VisibleForTesting - static final String RETRY_ERROR_LOG = "Error writing items to DynamoDB [attempts:{}]: {}"; - - private static final String RESUME_ERROR_LOG = - "Error writing remaining unprocessed items to DynamoDB: {}"; - - private static final String ERROR_NO_RETRY = - "Error writing to DynamoDB. No attempt made to retry"; - private static final String ERROR_RETRIES_EXCEEDED = - "Error writing to DynamoDB after %d attempt(s). No more attempts allowed"; - private static final String ERROR_UNPROCESSED_ITEMS = - "Error writing to DynamoDB. 
Unprocessed items remaining"; - - private transient FluentBackoff resumeBackoff; // resume from partial failures (unlimited) - private transient FluentBackoff retryBackoff; // retry erroneous calls (default: none) - - private static final Logger LOG = LoggerFactory.getLogger(WriteFn.class); - private static final Counter DYNAMO_DB_WRITE_FAILURES = - Metrics.counter(WriteFn.class, "DynamoDB_Write_Failures"); - - private static final int BATCH_SIZE = 25; - private transient AmazonDynamoDB client; - private final DynamoDBIO.Write spec; - private Map>, KV> batch; - - WriteFn(DynamoDBIO.Write spec) { - this.spec = spec; - } - - @Setup - public void setup() { - client = spec.getAwsClientsProvider().createDynamoDB(); - resumeBackoff = FluentBackoff.DEFAULT; // resume from partial failures (unlimited) - retryBackoff = FluentBackoff.DEFAULT.withMaxRetries(0); // retry on errors (default: none) - - RetryConfiguration retryConfig = spec.getRetryConfiguration(); - if (retryConfig != null) { - resumeBackoff = resumeBackoff.withInitialBackoff(retryConfig.getInitialDuration()); - retryBackoff = - retryBackoff - .withMaxRetries(retryConfig.getMaxAttempts() - 1) - .withInitialBackoff(retryConfig.getInitialDuration()) - .withMaxCumulativeBackoff(retryConfig.getMaxDuration()); - } - } - - @StartBundle - public void startBundle(StartBundleContext context) { - batch = new HashMap<>(); - } - - @ProcessElement - public void processElement(ProcessContext context) throws Exception { - final KV writeRequest = - spec.getWriteItemMapperFn().apply(context.element()); - batch.put( - KV.of(writeRequest.getKey(), extractDeduplicateKeyValues(writeRequest.getValue())), - writeRequest); - if (batch.size() >= BATCH_SIZE) { - flushBatch(); - } - } - - private Map extractDeduplicateKeyValues(WriteRequest request) { - List deduplicationKeys = spec.getDeduplicateKeys(); - Map attributes = Collections.emptyMap(); - - if (request.getPutRequest() != null) { - attributes = request.getPutRequest().getItem(); 
- } else if (request.getDeleteRequest() != null) { - attributes = request.getDeleteRequest().getKey(); - } - - if (attributes.isEmpty() || deduplicationKeys.isEmpty()) { - return attributes; - } - - return attributes.entrySet().stream() - .filter(entry -> deduplicationKeys.contains(entry.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - - @FinishBundle - public void finishBundle(FinishBundleContext context) throws Exception { - flushBatch(); - } - - private void flushBatch() throws IOException, InterruptedException { - if (batch.isEmpty()) { - return; - } - try { - // Group values KV by tableName - // Note: The original order of arrival is lost reading the map entries. - Map> writesPerTable = - batch.values().stream() - .collect(groupingBy(KV::getKey, mapping(KV::getValue, toList()))); - - // Backoff used to resume from partial failures - BackOff resume = resumeBackoff.backoff(); - do { - BatchWriteItemRequest batchRequest = new BatchWriteItemRequest(writesPerTable); - // If unprocessed items remain, we have to resume the operation (with backoff) - writesPerTable = writeWithRetries(batchRequest).getUnprocessedItems(); - } while (!writesPerTable.isEmpty() && BackOffUtils.next(Sleeper.DEFAULT, resume)); - - if (!writesPerTable.isEmpty()) { - DYNAMO_DB_WRITE_FAILURES.inc(); - LOG.error(RESUME_ERROR_LOG, writesPerTable); - throw new IOException(ERROR_UNPROCESSED_ITEMS); - } - } finally { - batch.clear(); - } - } - - /** - * Write batch of items to DynamoDB and potentially retry in case of exceptions. Though, in - * case of a partial failure, unprocessed items remain but the request succeeds. This has to - * be handled by the caller. 
- */ - private BatchWriteItemResult writeWithRetries(BatchWriteItemRequest request) - throws IOException, InterruptedException { - BackOff backoff = retryBackoff.backoff(); - Exception lastThrown; - - int attempt = 0; - do { - attempt++; - try { - return client.batchWriteItem(request); - } catch (Exception ex) { - lastThrown = ex; - } - } while (canRetry(lastThrown) && BackOffUtils.next(Sleeper.DEFAULT, backoff)); - - DYNAMO_DB_WRITE_FAILURES.inc(); - LOG.warn(RETRY_ERROR_LOG, attempt, request.getRequestItems()); - throw new IOException( - canRetry(lastThrown) ? String.format(ERROR_RETRIES_EXCEEDED, attempt) : ERROR_NO_RETRY, - lastThrown); - } - - private boolean canRetry(Exception ex) { - return spec.getRetryConfiguration() != null - && spec.getRetryConfiguration().getRetryPredicate().test(ex); - } - - @Teardown - public void tearDown() { - if (client != null) { - client.shutdown(); - client = null; - } - } - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/package-info.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/package-info.java deleted file mode 100644 index 0a7ea559fb9b..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/package-info.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** Defines IO connectors for Amazon Web Services DynamoDB. */ -package org.apache.beam.sdk.io.aws.dynamodb; diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java deleted file mode 100644 index 326758f1d1bb..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsModule.java +++ /dev/null @@ -1,390 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.options; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.BasicSessionCredentials; -import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper; -import com.amazonaws.auth.EnvironmentVariableCredentialsProvider; -import com.amazonaws.auth.PropertiesFileCredentialsProvider; -import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; -import com.amazonaws.auth.SystemPropertiesCredentialsProvider; -import com.amazonaws.auth.profile.ProfileCredentialsProvider; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import com.amazonaws.services.s3.model.SSECustomerKey; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.Module; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; -import com.fasterxml.jackson.databind.jsontype.TypeSerializer; -import com.fasterxml.jackson.databind.module.SimpleModule; -import com.google.auto.service.AutoService; -import java.io.IOException; -import 
java.util.Map; -import org.apache.beam.repackaged.core.org.apache.commons.lang3.reflect.FieldUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; - -/** - * A Jackson {@link Module} that registers a {@link JsonSerializer} and {@link JsonDeserializer} for - * {@link AWSCredentialsProvider} and some subclasses. The serialized form is a JSON map. - * - *

It also adds serializers for S3 encryption objects {@link SSECustomerKey} and {@link - * SSEAwsKeyManagementParams}. - */ -@AutoService(Module.class) -public class AwsModule extends SimpleModule { - - private static final String AWS_ACCESS_KEY_ID = "awsAccessKeyId"; - private static final String AWS_SECRET_KEY = "awsSecretKey"; - private static final String SESSION_TOKEN = "sessionToken"; - private static final String CREDENTIALS_FILE_PATH = "credentialsFilePath"; - public static final String CLIENT_EXECUTION_TIMEOUT = "clientExecutionTimeout"; - public static final String CONNECTION_MAX_IDLE_TIME = "connectionMaxIdleTime"; - public static final String CONNECTION_TIMEOUT = "connectionTimeout"; - public static final String CONNECTION_TIME_TO_LIVE = "connectionTimeToLive"; - public static final String MAX_CONNECTIONS = "maxConnections"; - public static final String REQUEST_TIMEOUT = "requestTimeout"; - public static final String SOCKET_TIMEOUT = "socketTimeout"; - public static final String PROXY_HOST = "proxyHost"; - public static final String PROXY_PORT = "proxyPort"; - public static final String PROXY_USERNAME = "proxyUsername"; - public static final String PROXY_PASSWORD = "proxyPassword"; - private static final String ROLE_ARN = "roleArn"; - private static final String ROLE_SESSION_NAME = "roleSessionName"; - - @SuppressWarnings({"nullness"}) - public AwsModule() { - super("AwsModule"); - setMixInAnnotation(AWSCredentialsProvider.class, AWSCredentialsProviderMixin.class); - setMixInAnnotation(SSECustomerKey.class, SSECustomerKeyMixin.class); - setMixInAnnotation(SSEAwsKeyManagementParams.class, SSEAwsKeyManagementParamsMixin.class); - setMixInAnnotation(ClientConfiguration.class, AwsHttpClientConfigurationMixin.class); - } - - /** A mixin to add Jackson annotations to {@link AWSCredentialsProvider}. 
*/ - @JsonDeserialize(using = AWSCredentialsProviderDeserializer.class) - @JsonSerialize(using = AWSCredentialsProviderSerializer.class) - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY) - private static class AWSCredentialsProviderMixin {} - - private static class AWSCredentialsProviderDeserializer - extends JsonDeserializer { - - @Override - public AWSCredentialsProvider deserialize(JsonParser jsonParser, DeserializationContext context) - throws IOException { - return context.readValue(jsonParser, AWSCredentialsProvider.class); - } - - @Override - public AWSCredentialsProvider deserializeWithType( - JsonParser jsonParser, DeserializationContext context, TypeDeserializer typeDeserializer) - throws IOException { - Map asMap = - checkNotNull( - jsonParser.readValueAs(new TypeReference>() {}), - "Serialized AWS credentials provider is null"); - - String typeNameKey = typeDeserializer.getPropertyName(); - String typeName = getNotNull(asMap, typeNameKey, "unknown"); - - if (hasName(AWSStaticCredentialsProvider.class, typeName)) { - boolean isSession = asMap.containsKey(SESSION_TOKEN); - if (isSession) { - return new AWSStaticCredentialsProvider( - new BasicSessionCredentials( - getNotNull(asMap, AWS_ACCESS_KEY_ID, typeName), - getNotNull(asMap, AWS_SECRET_KEY, typeName), - getNotNull(asMap, SESSION_TOKEN, typeName))); - } else { - return new AWSStaticCredentialsProvider( - new BasicAWSCredentials( - getNotNull(asMap, AWS_ACCESS_KEY_ID, typeName), - getNotNull(asMap, AWS_SECRET_KEY, typeName))); - } - } else if (hasName(PropertiesFileCredentialsProvider.class, typeName)) { - return new PropertiesFileCredentialsProvider( - getNotNull(asMap, CREDENTIALS_FILE_PATH, typeName)); - } else if (hasName(ClasspathPropertiesFileCredentialsProvider.class, typeName)) { - return new ClasspathPropertiesFileCredentialsProvider( - getNotNull(asMap, CREDENTIALS_FILE_PATH, typeName)); - } else if (hasName(DefaultAWSCredentialsProviderChain.class, typeName)) { 
- return new DefaultAWSCredentialsProviderChain(); - } else if (hasName(EnvironmentVariableCredentialsProvider.class, typeName)) { - return new EnvironmentVariableCredentialsProvider(); - } else if (hasName(SystemPropertiesCredentialsProvider.class, typeName)) { - return new SystemPropertiesCredentialsProvider(); - } else if (hasName(ProfileCredentialsProvider.class, typeName)) { - return new ProfileCredentialsProvider(); - } else if (hasName(EC2ContainerCredentialsProviderWrapper.class, typeName)) { - return new EC2ContainerCredentialsProviderWrapper(); - } else if (hasName(STSAssumeRoleSessionCredentialsProvider.class, typeName)) { - return new STSAssumeRoleSessionCredentialsProvider.Builder( - getNotNull(asMap, ROLE_ARN, typeName), - getNotNull(asMap, ROLE_SESSION_NAME, typeName)) - .build(); - } else { - throw new IOException( - String.format("AWS credential provider type '%s' is not supported", typeName)); - } - } - - @SuppressWarnings({"nullness"}) - private String getNotNull(Map map, String key, String typeName) { - return checkNotNull( - map.get(key), "AWS credentials provider type '%s' is missing '%s'", typeName, key); - } - - private boolean hasName(Class clazz, String typeName) { - return clazz.getSimpleName().equals(typeName); - } - } - - private static class AWSCredentialsProviderSerializer - extends JsonSerializer { - // These providers are singletons, so don't require any serialization, other than type. 
- private static final ImmutableSet SINGLETON_CREDENTIAL_PROVIDERS = - ImmutableSet.of( - DefaultAWSCredentialsProviderChain.class, - EnvironmentVariableCredentialsProvider.class, - SystemPropertiesCredentialsProvider.class, - ProfileCredentialsProvider.class, - EC2ContainerCredentialsProviderWrapper.class); - - @Override - public void serialize( - AWSCredentialsProvider credentialsProvider, - JsonGenerator jsonGenerator, - SerializerProvider serializers) - throws IOException { - serializers.defaultSerializeValue(credentialsProvider, jsonGenerator); - } - - @Override - public void serializeWithType( - AWSCredentialsProvider credentialsProvider, - JsonGenerator jsonGenerator, - SerializerProvider serializers, - TypeSerializer typeSerializer) - throws IOException { - // BEAM-11958 Use deprecated Jackson APIs to be compatible with older versions of jackson - typeSerializer.writeTypePrefixForObject(credentialsProvider, jsonGenerator); - - Class providerClass = credentialsProvider.getClass(); - if (providerClass.equals(AWSStaticCredentialsProvider.class)) { - AWSCredentials credentials = credentialsProvider.getCredentials(); - if (credentials.getClass().equals(BasicSessionCredentials.class)) { - BasicSessionCredentials sessionCredentials = (BasicSessionCredentials) credentials; - jsonGenerator.writeStringField(AWS_ACCESS_KEY_ID, sessionCredentials.getAWSAccessKeyId()); - jsonGenerator.writeStringField(AWS_SECRET_KEY, sessionCredentials.getAWSSecretKey()); - jsonGenerator.writeStringField(SESSION_TOKEN, sessionCredentials.getSessionToken()); - } else { - jsonGenerator.writeStringField(AWS_ACCESS_KEY_ID, credentials.getAWSAccessKeyId()); - jsonGenerator.writeStringField(AWS_SECRET_KEY, credentials.getAWSSecretKey()); - } - } else if (providerClass.equals(PropertiesFileCredentialsProvider.class)) { - String filePath = (String) readField(credentialsProvider, CREDENTIALS_FILE_PATH); - jsonGenerator.writeStringField(CREDENTIALS_FILE_PATH, filePath); - } else if 
(providerClass.equals(ClasspathPropertiesFileCredentialsProvider.class)) { - String filePath = (String) readField(credentialsProvider, CREDENTIALS_FILE_PATH); - jsonGenerator.writeStringField(CREDENTIALS_FILE_PATH, filePath); - } else if (providerClass.equals(STSAssumeRoleSessionCredentialsProvider.class)) { - String arn = (String) readField(credentialsProvider, ROLE_ARN); - String sessionName = (String) readField(credentialsProvider, ROLE_SESSION_NAME); - jsonGenerator.writeStringField(ROLE_ARN, arn); - jsonGenerator.writeStringField(ROLE_SESSION_NAME, sessionName); - } else if (!SINGLETON_CREDENTIAL_PROVIDERS.contains(providerClass)) { - throw new IllegalArgumentException( - "Unsupported AWS credentials provider type " + providerClass); - } - // BEAM-11958 Use deprecated Jackson APIs to be compatible with older versions of jackson - typeSerializer.writeTypeSuffixForObject(credentialsProvider, jsonGenerator); - } - - private Object readField(AWSCredentialsProvider provider, String fieldName) throws IOException { - try { - return FieldUtils.readField(provider, fieldName, true); - } catch (IllegalArgumentException | IllegalAccessException e) { - throw new IOException( - String.format( - "Failed to access private field '%s' of AWS credential provider type '%s' with reflection", - fieldName, provider.getClass().getSimpleName()), - e); - } - } - } - - @SuppressWarnings({"nullness"}) - private static String getNotNull(Map map, String key, Class clazz) { - return checkNotNull(map.get(key), "`%s` required in serialized %s", key, clazz.getSimpleName()); - } - - /** A mixin to add Jackson annotations to {@link SSECustomerKey}. 
*/ - @JsonDeserialize(using = SSECustomerKeyDeserializer.class) - private static class SSECustomerKeyMixin {} - - private static class SSECustomerKeyDeserializer extends JsonDeserializer { - @Override - public SSECustomerKey deserialize(JsonParser parser, DeserializationContext context) - throws IOException { - Map asMap = - checkNotNull( - parser.readValueAs(new TypeReference>() {}), - "Serialized SSECustomerKey is null"); - - SSECustomerKey sseCustomerKey = - new SSECustomerKey(getNotNull(asMap, "key", SSECustomerKey.class)); - final String algorithm = asMap.get("algorithm"); - final String md5 = asMap.get("md5"); - if (algorithm != null) { - sseCustomerKey.setAlgorithm(algorithm); - } - if (md5 != null) { - sseCustomerKey.setMd5(md5); - } - return sseCustomerKey; - } - } - - /** A mixin to add Jackson annotations to {@link SSEAwsKeyManagementParams}. */ - @JsonDeserialize(using = SSEAwsKeyManagementParamsDeserializer.class) - private static class SSEAwsKeyManagementParamsMixin {} - - private static class SSEAwsKeyManagementParamsDeserializer - extends JsonDeserializer { - @Override - public SSEAwsKeyManagementParams deserialize(JsonParser parser, DeserializationContext context) - throws IOException { - Map asMap = - checkNotNull( - parser.readValueAs(new TypeReference>() {}), - "Serialized SSEAwsKeyManagementParams is null"); - - return new SSEAwsKeyManagementParams( - getNotNull(asMap, "awsKmsKeyId", SSEAwsKeyManagementParams.class)); - } - } - - /** A mixin to add Jackson annotations to {@link ClientConfiguration}. 
*/ - @JsonSerialize(using = AwsHttpClientConfigurationSerializer.class) - @JsonDeserialize(using = AwsHttpClientConfigurationDeserializer.class) - private static class AwsHttpClientConfigurationMixin {} - - private static class AwsHttpClientConfigurationDeserializer - extends JsonDeserializer { - @Override - public ClientConfiguration deserialize(JsonParser jsonParser, DeserializationContext context) - throws IOException { - Map map = - checkNotNull( - jsonParser.readValueAs(new TypeReference>() {}), - "Serialized ClientConfiguration is null"); - - ClientConfiguration clientConfiguration = new ClientConfiguration(); - - if (map.containsKey(PROXY_HOST)) { - clientConfiguration.setProxyHost((String) map.get(PROXY_HOST)); - } - if (map.containsKey(PROXY_PORT)) { - clientConfiguration.setProxyPort(((Number) map.get(PROXY_PORT)).intValue()); - } - if (map.containsKey(PROXY_USERNAME)) { - clientConfiguration.setProxyUsername((String) map.get(PROXY_USERNAME)); - } - if (map.containsKey(PROXY_PASSWORD)) { - clientConfiguration.setProxyPassword((String) map.get(PROXY_PASSWORD)); - } - if (map.containsKey(CLIENT_EXECUTION_TIMEOUT)) { - clientConfiguration.setClientExecutionTimeout( - ((Number) map.get(CLIENT_EXECUTION_TIMEOUT)).intValue()); - } - if (map.containsKey(CONNECTION_MAX_IDLE_TIME)) { - clientConfiguration.setConnectionMaxIdleMillis( - ((Number) map.get(CONNECTION_MAX_IDLE_TIME)).longValue()); - } - if (map.containsKey(CONNECTION_TIMEOUT)) { - clientConfiguration.setConnectionTimeout(((Number) map.get(CONNECTION_TIMEOUT)).intValue()); - } - if (map.containsKey(CONNECTION_TIME_TO_LIVE)) { - clientConfiguration.setConnectionTTL( - ((Number) map.get(CONNECTION_TIME_TO_LIVE)).longValue()); - } - if (map.containsKey(MAX_CONNECTIONS)) { - clientConfiguration.setMaxConnections(((Number) map.get(MAX_CONNECTIONS)).intValue()); - } - if (map.containsKey(REQUEST_TIMEOUT)) { - clientConfiguration.setRequestTimeout(((Number) map.get(REQUEST_TIMEOUT)).intValue()); - } - if 
(map.containsKey(SOCKET_TIMEOUT)) { - clientConfiguration.setSocketTimeout(((Number) map.get(SOCKET_TIMEOUT)).intValue()); - } - return clientConfiguration; - } - } - - private static class AwsHttpClientConfigurationSerializer - extends JsonSerializer { - - @Override - public void serialize( - ClientConfiguration clientConfiguration, - JsonGenerator jsonGenerator, - SerializerProvider serializer) - throws IOException { - - jsonGenerator.writeStartObject(); - jsonGenerator.writeObjectField(PROXY_HOST /*string*/, clientConfiguration.getProxyHost()); - jsonGenerator.writeObjectField(PROXY_PORT /*int*/, clientConfiguration.getProxyPort()); - jsonGenerator.writeObjectField( - PROXY_USERNAME /*string*/, clientConfiguration.getProxyUsername()); - jsonGenerator.writeObjectField( - PROXY_PASSWORD /*string*/, clientConfiguration.getProxyPassword()); - jsonGenerator.writeObjectField( - CLIENT_EXECUTION_TIMEOUT /*int*/, clientConfiguration.getClientExecutionTimeout()); - jsonGenerator.writeObjectField( - CONNECTION_MAX_IDLE_TIME /*long*/, clientConfiguration.getConnectionMaxIdleMillis()); - jsonGenerator.writeObjectField( - CONNECTION_TIMEOUT /*int*/, clientConfiguration.getConnectionTimeout()); - jsonGenerator.writeObjectField( - CONNECTION_TIME_TO_LIVE /*long*/, clientConfiguration.getConnectionTTL()); - jsonGenerator.writeObjectField( - MAX_CONNECTIONS /*int*/, clientConfiguration.getMaxConnections()); - jsonGenerator.writeObjectField( - REQUEST_TIMEOUT /*int*/, clientConfiguration.getRequestTimeout()); - jsonGenerator.writeObjectField( - SOCKET_TIMEOUT /*int*/, clientConfiguration.getSocketTimeout()); - jsonGenerator.writeEndObject(); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsOptions.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsOptions.java deleted file mode 100644 index 42e3a5614b09..000000000000 --- 
a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsOptions.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.options; - -import com.amazonaws.ClientConfiguration; -import com.amazonaws.SdkClientException; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.regions.DefaultAwsRegionProviderChain; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.DefaultValueFactory; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.options.PipelineOptions; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * Options used to configure Amazon Web Services specific options such as credentials and region. - */ -public interface AwsOptions extends PipelineOptions { - - /** AWS region used by the AWS client. */ - @Description("AWS region used by the AWS client") - @Default.InstanceFactory(AwsRegionFactory.class) - String getAwsRegion(); - - void setAwsRegion(String value); - - /** Attempt to load default region. 
*/ - class AwsRegionFactory implements DefaultValueFactory<@Nullable String> { - @Override - @Nullable - public String create(PipelineOptions options) { - try { - return new DefaultAwsRegionProviderChain().getRegion(); - } catch (SdkClientException e) { - return null; - } - } - } - - /** The AWS service endpoint used by the AWS client. */ - @Description("AWS service endpoint used by the AWS client") - String getAwsServiceEndpoint(); - - void setAwsServiceEndpoint(String value); - - /** - * The credential instance that should be used to authenticate against AWS services. The option - * value must contain a "@type" field and an AWS Credentials Provider class as the field value. - * Refer to {@link DefaultAWSCredentialsProviderChain} Javadoc for usage help. - * - *

For example, to specify the AWS key ID and secret, specify the following: - * {"@type" : "AWSStaticCredentialsProvider", "awsAccessKeyId" : "key_id_value", - * "awsSecretKey" : "secret_value"} - * - */ - @Description( - "The credential instance that should be used to authenticate " - + "against AWS services. The option value must contain \"@type\" field " - + "and an AWS Credentials Provider class name as the field value. " - + "Refer to DefaultAWSCredentialsProviderChain Javadoc for usage help. " - + "For example, to specify the AWS key ID and secret, specify the following: " - + "{\"@type\": \"AWSStaticCredentialsProvider\", " - + "\"awsAccessKeyId\":\"\", \"awsSecretKey\":\"\"}") - @Default.InstanceFactory(AwsUserCredentialsFactory.class) - AWSCredentialsProvider getAwsCredentialsProvider(); - - void setAwsCredentialsProvider(AWSCredentialsProvider value); - - /** Attempts to load AWS credentials. */ - class AwsUserCredentialsFactory implements DefaultValueFactory { - - @Override - public AWSCredentialsProvider create(PipelineOptions options) { - return DefaultAWSCredentialsProviderChain.getInstance(); - } - } - - /** - * The client configuration instance that should be used to configure AWS service clients. Please - * note that the configuration deserialization only allows one to specify proxy settings. Please - * use AwsHttpClientConfiguration's client configuration to set a wider range of options. - * - *

For example, to specify the proxy host, port, username and password, specify the following: - * - * --clientConfiguration={ - * "proxyHost":"hostname", - * "proxyPort":1234, - * "proxyUsername":"username", - * "proxyPassword":"password" - * } - * - * - * @return - */ - @Description( - "The client configuration instance that should be used to configure AWS service " - + "clients. Please note that the configuration deserialization only allows one to specify " - + "proxy settings. For example, to specify the proxy host, port, username and password, " - + "specify the following: --clientConfiguration={\"proxyHost\":\"hostname\",\"proxyPort\":1234," - + "\"proxyUsername\":\"username\",\"proxyPassword\":\"password\"}") - @Default.InstanceFactory(ClientConfigurationFactory.class) - ClientConfiguration getClientConfiguration(); - - void setClientConfiguration(ClientConfiguration clientConfiguration); - - /** Default AWS client configuration. */ - class ClientConfigurationFactory implements DefaultValueFactory { - - @Override - public ClientConfiguration create(PipelineOptions options) { - return new ClientConfiguration(); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/S3ClientBuilderFactory.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/S3ClientBuilderFactory.java deleted file mode 100644 index ce6eaa57cd8e..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/S3ClientBuilderFactory.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.options; - -import com.amazonaws.services.s3.AmazonS3ClientBuilder; - -/** Construct AmazonS3ClientBuilder from S3 pipeline options. */ -public interface S3ClientBuilderFactory { - AmazonS3ClientBuilder createBuilder(S3Options s3Options); -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/S3Options.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/S3Options.java deleted file mode 100644 index e9979b5c99ea..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/S3Options.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.options; - -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import com.amazonaws.services.s3.model.SSECustomerKey; -import org.apache.beam.sdk.io.aws.s3.DefaultS3ClientBuilderFactory; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.DefaultValueFactory; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.options.PipelineOptions; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** Options used to configure Amazon Web Services S3. */ -public interface S3Options extends AwsOptions { - - @Description("AWS S3 storage class used for creating S3 objects") - @Default.String("STANDARD") - String getS3StorageClass(); - - void setS3StorageClass(String value); - - @Description( - "Size of S3 upload chunks; max upload object size is this value multiplied by 10000;" - + "default is 64MB, or 5MB in memory-constrained environments. Must be at least 5MB.") - @Default.InstanceFactory(S3UploadBufferSizeBytesFactory.class) - Integer getS3UploadBufferSizeBytes(); - - void setS3UploadBufferSizeBytes(Integer value); - - @Description("Thread pool size, limiting max concurrent S3 operations") - @Default.Integer(50) - int getS3ThreadPoolSize(); - - void setS3ThreadPoolSize(int value); - - @Description("Algorithm for SSE-S3 encryption, e.g. AES256.") - @Nullable - String getSSEAlgorithm(); - - void setSSEAlgorithm(String value); - - @Description( - "SSE key for SSE-C encryption, e.g. a base64 encoded key and the algorithm." - + "To specify on the command-line, represent the value as a JSON object. For example:" - + " --SSECustomerKey={\"key\": \"86glyTlCN...\", \"algorithm\": \"AES256\"}") - @Nullable - SSECustomerKey getSSECustomerKey(); - - void setSSECustomerKey(SSECustomerKey value); - - @Description( - "KMS key id for SSE-KMS encryption, e.g. \"arn:aws:kms:...\"." - + "To specify on the command-line, represent the value as a JSON object. 
For example:" - + " --SSEAwsKeyManagementParams={\"awsKmsKeyId\": \"arn:aws:kms:...\"}") - @Nullable - SSEAwsKeyManagementParams getSSEAwsKeyManagementParams(); - - void setSSEAwsKeyManagementParams(SSEAwsKeyManagementParams value); - - @Description( - "Set to true to use an S3 Bucket Key for object encryption with server-side " - + "encryption using AWS KMS (SSE-KMS)") - @Default.Boolean(false) - boolean getBucketKeyEnabled(); - - void setBucketKeyEnabled(boolean value); - - @Description( - "Factory class that should be created and used to create a builder of AmazonS3 client." - + "Override the default value if you need a S3 client with custom properties, like path style access, etc.") - @Default.Class(DefaultS3ClientBuilderFactory.class) - Class getS3ClientFactoryClass(); - - void setS3ClientFactoryClass(Class s3ClientFactoryClass); - - /** - * Provide the default s3 upload buffer size in bytes: 64MB if more than 512MB in RAM are - * available and 5MB otherwise. - */ - class S3UploadBufferSizeBytesFactory implements DefaultValueFactory { - public static final int MINIMUM_UPLOAD_BUFFER_SIZE_BYTES = 5_242_880; - - @Override - public Integer create(PipelineOptions options) { - return Runtime.getRuntime().maxMemory() < 536_870_912 - ? MINIMUM_UPLOAD_BUFFER_SIZE_BYTES - : 67_108_864; - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/package-info.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/package-info.java deleted file mode 100644 index fc79c546706a..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Defines {@link org.apache.beam.sdk.options.PipelineOptions} for configuring pipeline execution - * for Amazon Web Services components. - */ -package org.apache.beam.sdk.io.aws.options; diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/DefaultS3ClientBuilderFactory.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/DefaultS3ClientBuilderFactory.java deleted file mode 100644 index fa96d79b63a7..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/DefaultS3ClientBuilderFactory.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import org.apache.beam.sdk.io.aws.options.S3ClientBuilderFactory; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; - -/** - * Construct AmazonS3ClientBuilder with default values of S3 client properties like path style - * access, accelerated mode, etc. - */ -public class DefaultS3ClientBuilderFactory implements S3ClientBuilderFactory { - - @Override - public AmazonS3ClientBuilder createBuilder(S3Options s3Options) { - AmazonS3ClientBuilder builder = - AmazonS3ClientBuilder.standard().withCredentials(s3Options.getAwsCredentialsProvider()); - - if (s3Options.getClientConfiguration() != null) { - builder = builder.withClientConfiguration(s3Options.getClientConfiguration()); - } - - if (!Strings.isNullOrEmpty(s3Options.getAwsServiceEndpoint())) { - builder = - builder.withEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration( - s3Options.getAwsServiceEndpoint(), s3Options.getAwsRegion())); - } else if (!Strings.isNullOrEmpty(s3Options.getAwsRegion())) { - builder = builder.withRegion(s3Options.getAwsRegion()); - } - return builder; - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/DefaultS3FileSystemSchemeRegistrar.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/DefaultS3FileSystemSchemeRegistrar.java deleted file mode 100644 index 0988309cb0e2..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/DefaultS3FileSystemSchemeRegistrar.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.google.auto.service.AutoService; -import javax.annotation.Nonnull; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; - -/** Registers the "s3" uri schema to be handled by {@link S3FileSystem}. 
*/ -@AutoService(S3FileSystemSchemeRegistrar.class) -public class DefaultS3FileSystemSchemeRegistrar implements S3FileSystemSchemeRegistrar { - - @Override - public Iterable fromOptions(@Nonnull PipelineOptions options) { - checkNotNull(options, "Expect the runner have called FileSystems.setDefaultPipelineOptions()."); - return ImmutableList.of( - S3FileSystemConfiguration.fromS3Options(options.as(S3Options.class)).build()); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java deleted file mode 100644 index 75d66c46478a..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java +++ /dev/null @@ -1,671 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.sdk.io.FileSystemUtils.wildcardToRegexp; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - -import com.amazonaws.AmazonClientException; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.AmazonS3Exception; -import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest; -import com.amazonaws.services.s3.model.CompleteMultipartUploadResult; -import com.amazonaws.services.s3.model.CopyObjectRequest; -import com.amazonaws.services.s3.model.CopyObjectResult; -import com.amazonaws.services.s3.model.CopyPartRequest; -import com.amazonaws.services.s3.model.CopyPartResult; -import com.amazonaws.services.s3.model.DeleteObjectsRequest; -import com.amazonaws.services.s3.model.DeleteObjectsRequest.KeyVersion; -import com.amazonaws.services.s3.model.GetObjectMetadataRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadResult; -import com.amazonaws.services.s3.model.ListObjectsV2Request; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.PartETag; -import com.amazonaws.services.s3.model.S3ObjectSummary; -import com.google.auto.value.AutoValue; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Date; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import 
java.util.concurrent.Callable; -import java.util.concurrent.CompletionStage; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import org.apache.beam.sdk.io.FileSystem; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.io.fs.CreateOptions; -import org.apache.beam.sdk.io.fs.MatchResult; -import org.apache.beam.sdk.io.fs.MoveOptions; -import org.apache.beam.sdk.metrics.Lineage; -import org.apache.beam.sdk.util.MoreFutures; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListeningExecutorService; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * {@link FileSystem} implementation for storage systems that use the S3 protocol. 
- * - * @see S3FileSystemSchemeRegistrar - * @deprecated Module beam-sdks-java-io-amazon-web-services is deprecated and will be - * eventually removed. Please migrate to module beam-sdks-java-io-amazon-web-services2 - * . - */ -@Deprecated -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class S3FileSystem extends FileSystem { - - private static final Logger LOG = LoggerFactory.getLogger(S3FileSystem.class); - - // Amazon S3 API: You can create a copy of your object up to 5 GB in a single atomic operation - // Ref. https://docs.aws.amazon.com/AmazonS3/latest/dev/CopyingObjectsExamples.html - private static final long MAX_COPY_OBJECT_SIZE_BYTES = 5_368_709_120L; - - // S3 API, delete-objects: "You may specify up to 1000 keys." - private static final int MAX_DELETE_OBJECTS_PER_REQUEST = 1000; - - private static final ImmutableSet NON_READ_SEEK_EFFICIENT_ENCODINGS = - ImmutableSet.of("gzip"); - - // Non-final for testing. - private Supplier amazonS3; - private final S3FileSystemConfiguration config; - private final ListeningExecutorService executorService; - - S3FileSystem(S3FileSystemConfiguration config) { - this.config = checkNotNull(config, "config"); - // The Supplier is to make sure we don't call .build() unless we are actually using S3. - amazonS3 = Suppliers.memoize(config.getS3ClientBuilder()::build); - - checkNotNull(config.getS3StorageClass(), "storageClass"); - checkArgument(config.getS3ThreadPoolSize() > 0, "threadPoolSize"); - executorService = - MoreExecutors.listeningDecorator( - Executors.newFixedThreadPool( - config.getS3ThreadPoolSize(), new ThreadFactoryBuilder().setDaemon(true).build())); - - LOG.warn( - "You are using a deprecated file system for S3. 
Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - } - - S3FileSystem(S3Options options) { - this(S3FileSystemConfiguration.fromS3Options(options).build()); - } - - @Override - protected String getScheme() { - return config.getScheme(); - } - - @VisibleForTesting - void setAmazonS3Client(AmazonS3 amazonS3) { - this.amazonS3 = Suppliers.ofInstance(amazonS3); - } - - @VisibleForTesting - AmazonS3 getAmazonS3Client() { - return this.amazonS3.get(); - } - - @Override - protected List match(List specs) throws IOException { - List paths = - specs.stream().map(S3ResourceId::fromUri).collect(Collectors.toList()); - List globs = new ArrayList<>(); - List nonGlobs = new ArrayList<>(); - List isGlobBooleans = new ArrayList<>(); - - for (S3ResourceId path : paths) { - if (path.isWildcard()) { - globs.add(path); - isGlobBooleans.add(true); - } else { - nonGlobs.add(path); - isGlobBooleans.add(false); - } - } - - Iterator globMatches = matchGlobPaths(globs).iterator(); - Iterator nonGlobMatches = matchNonGlobPaths(nonGlobs).iterator(); - - ImmutableList.Builder matchResults = ImmutableList.builder(); - for (Boolean isGlob : isGlobBooleans) { - if (isGlob) { - checkState( - globMatches.hasNext(), - "Internal error encountered in S3Filesystem: expected more elements in globMatches."); - matchResults.add(globMatches.next()); - } else { - checkState( - nonGlobMatches.hasNext(), - "Internal error encountered in S3Filesystem: expected more elements in nonGlobMatches."); - matchResults.add(nonGlobMatches.next()); - } - } - checkState( - !globMatches.hasNext(), - "Internal error encountered in S3Filesystem: expected no more elements in globMatches."); - checkState( - !nonGlobMatches.hasNext(), - "Internal error encountered in S3Filesystem: expected no more elements in nonGlobMatches."); - - return matchResults.build(); - } - - /** Gets {@link MatchResult} representing all objects that match wildcard-containing paths. 
*/ - @VisibleForTesting - List matchGlobPaths(Collection globPaths) throws IOException { - List> expandTasks = new ArrayList<>(globPaths.size()); - for (final S3ResourceId path : globPaths) { - expandTasks.add(() -> expandGlob(path)); - } - - Map expandedGlobByGlobPath = new HashMap<>(); - List> contentTypeTasks = new ArrayList<>(globPaths.size()); - for (ExpandedGlob expandedGlob : callTasks(expandTasks)) { - expandedGlobByGlobPath.put(expandedGlob.getGlobPath(), expandedGlob); - if (expandedGlob.getExpandedPaths() != null) { - for (final S3ResourceId path : expandedGlob.getExpandedPaths()) { - contentTypeTasks.add(() -> getPathContentEncoding(path)); - } - } - } - - Map exceptionByPath = new HashMap<>(); - for (PathWithEncoding pathWithException : callTasks(contentTypeTasks)) { - exceptionByPath.put(pathWithException.getPath(), pathWithException); - } - - List results = new ArrayList<>(globPaths.size()); - for (S3ResourceId globPath : globPaths) { - ExpandedGlob expandedGlob = expandedGlobByGlobPath.get(globPath); - - if (expandedGlob.getException() != null) { - results.add(MatchResult.create(MatchResult.Status.ERROR, expandedGlob.getException())); - - } else { - List metadatas = new ArrayList<>(); - IOException exception = null; - for (S3ResourceId expandedPath : expandedGlob.getExpandedPaths()) { - PathWithEncoding pathWithEncoding = exceptionByPath.get(expandedPath); - - if (pathWithEncoding.getException() != null) { - exception = pathWithEncoding.getException(); - break; - } else { - // TODO(https://github.com/apache/beam/issues/20755): Support file checksum in this - // method. 
- metadatas.add( - createBeamMetadata( - pathWithEncoding.getPath(), pathWithEncoding.getContentEncoding(), null)); - } - } - - if (exception != null) { - if (exception instanceof FileNotFoundException) { - results.add(MatchResult.create(MatchResult.Status.NOT_FOUND, exception)); - } else { - results.add(MatchResult.create(MatchResult.Status.ERROR, exception)); - } - } else { - results.add(MatchResult.create(MatchResult.Status.OK, metadatas)); - } - } - } - - return ImmutableList.copyOf(results); - } - - @AutoValue - abstract static class ExpandedGlob { - - abstract S3ResourceId getGlobPath(); - - abstract @Nullable List getExpandedPaths(); - - abstract @Nullable IOException getException(); - - static ExpandedGlob create(S3ResourceId globPath, List expandedPaths) { - checkNotNull(globPath, "globPath"); - checkNotNull(expandedPaths, "expandedPaths"); - return new AutoValue_S3FileSystem_ExpandedGlob(globPath, expandedPaths, null); - } - - static ExpandedGlob create(S3ResourceId globPath, IOException exception) { - checkNotNull(globPath, "globPath"); - checkNotNull(exception, "exception"); - return new AutoValue_S3FileSystem_ExpandedGlob(globPath, null, exception); - } - } - - @AutoValue - abstract static class PathWithEncoding { - - abstract S3ResourceId getPath(); - - abstract @Nullable String getContentEncoding(); - - abstract @Nullable IOException getException(); - - static PathWithEncoding create(S3ResourceId path, String contentEncoding) { - checkNotNull(path, "path"); - checkNotNull(contentEncoding, "contentEncoding"); - return new AutoValue_S3FileSystem_PathWithEncoding(path, contentEncoding, null); - } - - static PathWithEncoding create(S3ResourceId path, IOException exception) { - checkNotNull(path, "path"); - checkNotNull(exception, "exception"); - return new AutoValue_S3FileSystem_PathWithEncoding(path, null, exception); - } - } - - private ExpandedGlob expandGlob(S3ResourceId glob) { - // The S3 API can list objects, filtered by prefix, but not by 
wildcard. - // Here, we find the longest prefix without wildcard "*", - // then filter the results with a regex. - checkArgument(glob.isWildcard(), "isWildcard"); - String keyPrefix = glob.getKeyNonWildcardPrefix(); - Pattern wildcardRegexp = Pattern.compile(wildcardToRegexp(glob.getKey())); - - LOG.debug( - "expanding bucket {}, prefix {}, against pattern {}", - glob.getBucket(), - keyPrefix, - wildcardRegexp.toString()); - - ImmutableList.Builder expandedPaths = ImmutableList.builder(); - String continuationToken = null; - - do { - ListObjectsV2Request request = - new ListObjectsV2Request() - .withBucketName(glob.getBucket()) - .withPrefix(keyPrefix) - .withContinuationToken(continuationToken); - ListObjectsV2Result result; - try { - result = amazonS3.get().listObjectsV2(request); - } catch (AmazonClientException e) { - return ExpandedGlob.create(glob, new IOException(e)); - } - continuationToken = result.getNextContinuationToken(); - - for (S3ObjectSummary objectSummary : result.getObjectSummaries()) { - // Filter against regex. 
- if (wildcardRegexp.matcher(objectSummary.getKey()).matches()) { - S3ResourceId expandedPath = - S3ResourceId.fromComponents( - glob.getScheme(), objectSummary.getBucketName(), objectSummary.getKey()) - .withSize(objectSummary.getSize()) - .withLastModified(objectSummary.getLastModified()); - LOG.debug("Expanded S3 object path {}", expandedPath); - expandedPaths.add(expandedPath); - } - } - } while (continuationToken != null); - - return ExpandedGlob.create(glob, expandedPaths.build()); - } - - private PathWithEncoding getPathContentEncoding(S3ResourceId path) { - ObjectMetadata s3Metadata; - try { - s3Metadata = getObjectMetadata(path); - } catch (AmazonClientException e) { - if (e instanceof AmazonS3Exception && ((AmazonS3Exception) e).getStatusCode() == 404) { - return PathWithEncoding.create(path, new FileNotFoundException()); - } - return PathWithEncoding.create(path, new IOException(e)); - } - return PathWithEncoding.create(path, Strings.nullToEmpty(s3Metadata.getContentEncoding())); - } - - private List matchNonGlobPaths(Collection paths) throws IOException { - List> tasks = new ArrayList<>(paths.size()); - for (final S3ResourceId path : paths) { - tasks.add(() -> matchNonGlobPath(path)); - } - - return callTasks(tasks); - } - - private ObjectMetadata getObjectMetadata(S3ResourceId s3ResourceId) throws AmazonClientException { - GetObjectMetadataRequest request = - new GetObjectMetadataRequest(s3ResourceId.getBucket(), s3ResourceId.getKey()); - request.setSSECustomerKey(config.getSSECustomerKey()); - return amazonS3.get().getObjectMetadata(request); - } - - @VisibleForTesting - MatchResult matchNonGlobPath(S3ResourceId path) { - ObjectMetadata s3Metadata; - try { - s3Metadata = getObjectMetadata(path); - } catch (AmazonClientException e) { - if (e instanceof AmazonS3Exception && ((AmazonS3Exception) e).getStatusCode() == 404) { - return MatchResult.create(MatchResult.Status.NOT_FOUND, new FileNotFoundException()); - } - return 
MatchResult.create(MatchResult.Status.ERROR, new IOException(e)); - } - - return MatchResult.create( - MatchResult.Status.OK, - ImmutableList.of( - createBeamMetadata( - path.withSize(s3Metadata.getContentLength()) - .withLastModified(s3Metadata.getLastModified()), - Strings.nullToEmpty(s3Metadata.getContentEncoding()), - s3Metadata.getETag()))); - } - - private static MatchResult.Metadata createBeamMetadata( - S3ResourceId path, String contentEncoding, String eTag) { - checkArgument(path.getSize().isPresent(), "The resource id should have a size."); - checkNotNull(contentEncoding, "contentEncoding"); - boolean isReadSeekEfficient = !NON_READ_SEEK_EFFICIENT_ENCODINGS.contains(contentEncoding); - - MatchResult.Metadata.Builder ret = - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(isReadSeekEfficient) - .setResourceId(path) - .setSizeBytes(path.getSize().get()) - .setLastModifiedMillis(path.getLastModified().transform(Date::getTime).or(0L)); - if (eTag != null) { - ret.setChecksum(eTag); - } - return ret.build(); - } - - @Override - protected WritableByteChannel create(S3ResourceId resourceId, CreateOptions createOptions) - throws IOException { - return new S3WritableByteChannel(amazonS3.get(), resourceId, createOptions.mimeType(), config); - } - - @Override - protected ReadableByteChannel open(S3ResourceId resourceId) throws IOException { - return new S3ReadableSeekableByteChannel(amazonS3.get(), resourceId, config); - } - - @Override - protected void copy(List sourcePaths, List destinationPaths) - throws IOException { - checkArgument( - sourcePaths.size() == destinationPaths.size(), - "sizes of sourcePaths and destinationPaths do not match"); - - List> tasks = new ArrayList<>(sourcePaths.size()); - - Iterator sourcePathsIterator = sourcePaths.iterator(); - Iterator destinationPathsIterator = destinationPaths.iterator(); - while (sourcePathsIterator.hasNext()) { - final S3ResourceId sourcePath = sourcePathsIterator.next(); - final S3ResourceId 
destinationPath = destinationPathsIterator.next(); - - tasks.add( - () -> { - copy(sourcePath, destinationPath); - return null; - }); - } - - callTasks(tasks); - } - - @VisibleForTesting - void copy(S3ResourceId sourcePath, S3ResourceId destinationPath) throws IOException { - try { - ObjectMetadata sourceObjectMetadata = getObjectMetadata(sourcePath); - if (sourceObjectMetadata.getContentLength() < MAX_COPY_OBJECT_SIZE_BYTES) { - atomicCopy(sourcePath, destinationPath, sourceObjectMetadata); - } else { - multipartCopy(sourcePath, destinationPath, sourceObjectMetadata); - } - } catch (AmazonClientException e) { - throw new IOException(e); - } - } - - @VisibleForTesting - CopyObjectResult atomicCopy( - S3ResourceId sourcePath, S3ResourceId destinationPath, ObjectMetadata sourceObjectMetadata) - throws AmazonClientException { - CopyObjectRequest copyObjectRequest = - new CopyObjectRequest( - sourcePath.getBucket(), - sourcePath.getKey(), - destinationPath.getBucket(), - destinationPath.getKey()); - copyObjectRequest.setNewObjectMetadata(sourceObjectMetadata); - copyObjectRequest.setStorageClass(config.getS3StorageClass()); - copyObjectRequest.setSourceSSECustomerKey(config.getSSECustomerKey()); - copyObjectRequest.setDestinationSSECustomerKey(config.getSSECustomerKey()); - return amazonS3.get().copyObject(copyObjectRequest); - } - - @VisibleForTesting - CompleteMultipartUploadResult multipartCopy( - S3ResourceId sourcePath, S3ResourceId destinationPath, ObjectMetadata sourceObjectMetadata) - throws AmazonClientException { - InitiateMultipartUploadRequest initiateUploadRequest = - new InitiateMultipartUploadRequest(destinationPath.getBucket(), destinationPath.getKey()) - .withStorageClass(config.getS3StorageClass()) - .withObjectMetadata(sourceObjectMetadata) - .withSSECustomerKey(config.getSSECustomerKey()); - - InitiateMultipartUploadResult initiateUploadResult = - amazonS3.get().initiateMultipartUpload(initiateUploadRequest); - final String uploadId = 
initiateUploadResult.getUploadId(); - - List eTags = new ArrayList<>(); - - final long objectSize = sourceObjectMetadata.getContentLength(); - // extra validation in case a caller calls directly S3FileSystem.multipartCopy - // without using S3FileSystem.copy in the future - if (objectSize == 0) { - final CopyPartRequest copyPartRequest = - new CopyPartRequest() - .withSourceBucketName(sourcePath.getBucket()) - .withSourceKey(sourcePath.getKey()) - .withDestinationBucketName(destinationPath.getBucket()) - .withDestinationKey(destinationPath.getKey()) - .withUploadId(uploadId) - .withPartNumber(1); - copyPartRequest.setSourceSSECustomerKey(config.getSSECustomerKey()); - copyPartRequest.setDestinationSSECustomerKey(config.getSSECustomerKey()); - - CopyPartResult copyPartResult = amazonS3.get().copyPart(copyPartRequest); - eTags.add(copyPartResult.getPartETag()); - } else { - long bytePosition = 0; - // Amazon parts are 1-indexed, not zero-indexed. - for (int partNumber = 1; bytePosition < objectSize; partNumber++) { - final CopyPartRequest copyPartRequest = - new CopyPartRequest() - .withSourceBucketName(sourcePath.getBucket()) - .withSourceKey(sourcePath.getKey()) - .withDestinationBucketName(destinationPath.getBucket()) - .withDestinationKey(destinationPath.getKey()) - .withUploadId(uploadId) - .withPartNumber(partNumber) - .withFirstByte(bytePosition) - .withLastByte( - Math.min(objectSize - 1, bytePosition + MAX_COPY_OBJECT_SIZE_BYTES - 1)); - copyPartRequest.setSourceSSECustomerKey(config.getSSECustomerKey()); - copyPartRequest.setDestinationSSECustomerKey(config.getSSECustomerKey()); - - CopyPartResult copyPartResult = amazonS3.get().copyPart(copyPartRequest); - eTags.add(copyPartResult.getPartETag()); - - bytePosition += MAX_COPY_OBJECT_SIZE_BYTES; - } - } - - CompleteMultipartUploadRequest completeUploadRequest = - new CompleteMultipartUploadRequest() - .withBucketName(destinationPath.getBucket()) - .withKey(destinationPath.getKey()) - .withUploadId(uploadId) 
- .withPartETags(eTags); - return amazonS3.get().completeMultipartUpload(completeUploadRequest); - } - - @Override - protected void rename( - List sourceResourceIds, - List destinationResourceIds, - MoveOptions... moveOptions) - throws IOException { - if (moveOptions.length > 0) { - throw new UnsupportedOperationException("Support for move options is not yet implemented."); - } - copy(sourceResourceIds, destinationResourceIds); - delete(sourceResourceIds); - } - - @Override - protected void delete(Collection resourceIds) throws IOException { - List nonDirectoryPaths = - resourceIds.stream() - .filter(s3ResourceId -> !s3ResourceId.isDirectory()) - .collect(Collectors.toList()); - Multimap keysByBucket = ArrayListMultimap.create(); - for (S3ResourceId path : nonDirectoryPaths) { - keysByBucket.put(path.getBucket(), path.getKey()); - } - - List> tasks = new ArrayList<>(); - for (final String bucket : keysByBucket.keySet()) { - for (final List keysPartition : - Iterables.partition(keysByBucket.get(bucket), MAX_DELETE_OBJECTS_PER_REQUEST)) { - tasks.add( - () -> { - delete(bucket, keysPartition); - return null; - }); - } - } - - callTasks(tasks); - } - - private void delete(String bucket, Collection keys) throws IOException { - checkArgument( - keys.size() <= MAX_DELETE_OBJECTS_PER_REQUEST, - "only %s keys can be deleted per request, but got %s", - MAX_DELETE_OBJECTS_PER_REQUEST, - keys.size()); - List deleteKeyVersions = - keys.stream().map(KeyVersion::new).collect(Collectors.toList()); - DeleteObjectsRequest request = - new DeleteObjectsRequest(bucket).withKeys(deleteKeyVersions).withQuiet(true); - try { - amazonS3.get().deleteObjects(request); - } catch (AmazonClientException e) { - throw new IOException(e); - } - } - - @Override - protected S3ResourceId matchNewResource(String singleResourceSpec, boolean isDirectory) { - if (isDirectory) { - if (!singleResourceSpec.endsWith("/")) { - singleResourceSpec += "/"; - } - } else { - checkArgument( - 
!singleResourceSpec.endsWith("/"), - "Expected a file path, but [%s] ends with '/'. This is unsupported in S3FileSystem.", - singleResourceSpec); - } - return S3ResourceId.fromUri(singleResourceSpec); - } - - @Override - protected void reportLineage(S3ResourceId resourceId, Lineage lineage) { - reportLineage(resourceId, lineage, LineageLevel.FILE); - } - - @Override - protected void reportLineage(S3ResourceId resourceId, Lineage lineage, LineageLevel level) { - ImmutableList.Builder segments = - ImmutableList.builder().add(resourceId.getBucket()); - if (level != LineageLevel.TOP_LEVEL && !resourceId.getKey().isEmpty()) { - segments.add(resourceId.getKey()); - } - lineage.add("s3", segments.build()); - } - - /** - * Invokes tasks in a thread pool, then unwraps the resulting {@link Future Futures}. - * - *

Any task exception is wrapped in {@link IOException}. - */ - private List callTasks(Collection> tasks) throws IOException { - - try { - List> futures = new ArrayList<>(tasks.size()); - for (Callable task : tasks) { - futures.add(MoreFutures.supplyAsync(task::call, executorService)); - } - return MoreFutures.get(MoreFutures.allAsList(futures)); - - } catch (ExecutionException e) { - if (e.getCause() != null) { - if (e.getCause() instanceof IOException) { - throw (IOException) e.getCause(); - } - throw new IOException(e.getCause()); - } - throw new IOException(e); - - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("executor service was interrupted"); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemConfiguration.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemConfiguration.java deleted file mode 100644 index 248f99aa0651..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemConfiguration.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.s3; - -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import com.amazonaws.services.s3.model.SSECustomerKey; -import com.google.auto.value.AutoValue; -import javax.annotation.Nullable; -import org.apache.beam.sdk.io.aws.options.S3ClientBuilderFactory; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.util.InstanceBuilder; - -/** - * Object used to configure {@link S3FileSystem}. - * - * @see S3Options - * @see S3FileSystemSchemeRegistrar - */ -@AutoValue -public abstract class S3FileSystemConfiguration { - public static final int MINIMUM_UPLOAD_BUFFER_SIZE_BYTES = - S3Options.S3UploadBufferSizeBytesFactory.MINIMUM_UPLOAD_BUFFER_SIZE_BYTES; - - /** The uri scheme used by resources on this filesystem. */ - public abstract String getScheme(); - - /** The AWS S3 storage class used for creating S3 objects. */ - public abstract String getS3StorageClass(); - - /** Size of S3 upload chunks. */ - public abstract int getS3UploadBufferSizeBytes(); - - /** Thread pool size, limiting the max concurrent S3 operations. */ - public abstract int getS3ThreadPoolSize(); - - /** Algorithm for SSE-S3 encryption, e.g. AES256. */ - public abstract @Nullable String getSSEAlgorithm(); - - /** SSE key for SSE-C encryption, e.g. a base64 encoded key and the algorithm. */ - public abstract @Nullable SSECustomerKey getSSECustomerKey(); - - /** KMS key id for SSE-KMS encryption, e.g. "arn:aws:kms:...". */ - public abstract @Nullable SSEAwsKeyManagementParams getSSEAwsKeyManagementParams(); - - /** - * Whether to ose an S3 Bucket Key for object encryption with server-side encryption using AWS KMS - * (SSE-KMS) or not. - */ - public abstract boolean getBucketKeyEnabled(); - - /** Builder used to create the {@code AmazonS3Client}. */ - public abstract AmazonS3ClientBuilder getS3ClientBuilder(); - - /** Creates a new uninitialized {@link Builder}. 
*/ - public static Builder builder() { - return new AutoValue_S3FileSystemConfiguration.Builder(); - } - - /** Creates a new {@link Builder} with values initialized by this instance's properties. */ - public abstract Builder toBuilder(); - - /** - * Creates a new {@link Builder} with values initialized by the properties of {@code s3Options}. - */ - public static Builder fromS3Options(S3Options s3Options) { - return builder() - .setScheme("s3") - .setS3StorageClass(s3Options.getS3StorageClass()) - .setS3UploadBufferSizeBytes(s3Options.getS3UploadBufferSizeBytes()) - .setS3ThreadPoolSize(s3Options.getS3ThreadPoolSize()) - .setSSEAlgorithm(s3Options.getSSEAlgorithm()) - .setSSECustomerKey(s3Options.getSSECustomerKey()) - .setSSEAwsKeyManagementParams(s3Options.getSSEAwsKeyManagementParams()) - .setBucketKeyEnabled(s3Options.getBucketKeyEnabled()) - .setS3ClientBuilder(getBuilder(s3Options)); - } - - /** Creates a new {@link AmazonS3ClientBuilder} as specified by {@code s3Options}. */ - public static AmazonS3ClientBuilder getBuilder(S3Options s3Options) { - return InstanceBuilder.ofType(S3ClientBuilderFactory.class) - .fromClass(s3Options.getS3ClientFactoryClass()) - .build() - .createBuilder(s3Options); - } - - @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setScheme(String value); - - public abstract Builder setS3StorageClass(String value); - - public abstract Builder setS3UploadBufferSizeBytes(int value); - - public abstract Builder setS3ThreadPoolSize(int value); - - public abstract Builder setSSEAlgorithm(@Nullable String value); - - public abstract Builder setSSECustomerKey(@Nullable SSECustomerKey value); - - public abstract Builder setSSEAwsKeyManagementParams(@Nullable SSEAwsKeyManagementParams value); - - public abstract Builder setBucketKeyEnabled(boolean value); - - public abstract Builder setS3ClientBuilder(AmazonS3ClientBuilder value); - - public abstract S3FileSystemConfiguration build(); - } -} diff --git 
a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemRegistrar.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemRegistrar.java deleted file mode 100644 index af153de42622..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemRegistrar.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.google.auto.service.AutoService; -import java.util.Map; -import java.util.ServiceLoader; -import java.util.stream.Collectors; -import javax.annotation.Nonnull; -import org.apache.beam.sdk.io.FileSystem; -import org.apache.beam.sdk.io.FileSystemRegistrar; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.util.common.ReflectHelpers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; - -/** - * {@link AutoService} registrar for the {@link S3FileSystem}. - * - *

Creates instances of {@link S3FileSystem} for each scheme registered with a {@link - * S3FileSystemSchemeRegistrar}. - */ -@AutoService(FileSystemRegistrar.class) -public class S3FileSystemRegistrar implements FileSystemRegistrar { - - @Override - public Iterable> fromOptions(@Nonnull PipelineOptions options) { - checkNotNull(options, "Expect the runner have called FileSystems.setDefaultPipelineOptions()."); - Map> fileSystems = - Streams.stream( - ServiceLoader.load( - S3FileSystemSchemeRegistrar.class, ReflectHelpers.findClassLoader())) - .flatMap(r -> Streams.stream(r.fromOptions(options))) - .map(S3FileSystem::new) - // Throws IllegalStateException if any duplicate schemes exist. - .collect(Collectors.toMap(S3FileSystem::getScheme, f -> (FileSystem) f)); - return fileSystems.values(); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemSchemeRegistrar.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemSchemeRegistrar.java deleted file mode 100644 index 191b6f2cd244..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemSchemeRegistrar.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import com.google.auto.service.AutoService; -import java.util.ServiceLoader; -import javax.annotation.Nonnull; -import org.apache.beam.sdk.io.FileSystem; -import org.apache.beam.sdk.io.FileSystemRegistrar; -import org.apache.beam.sdk.options.PipelineOptions; - -/** - * A registrar that creates {@link S3FileSystemConfiguration} instances from {@link - * PipelineOptions}. - * - *

Users of storage systems that use the S3 protocol have the ability to register a URI scheme by - * creating a {@link ServiceLoader} entry and a concrete implementation of this interface. - * - *

It is optional but recommended to use one of the many build time tools such as {@link - * AutoService} to generate the necessary META-INF files automatically. - */ -public interface S3FileSystemSchemeRegistrar { - /** - * Create zero or more {@link S3FileSystemConfiguration} instances from the given {@link - * PipelineOptions}. - * - *

Each {@link S3FileSystemConfiguration#getScheme() scheme} is required to be unique among all - * schemes registered by all {@link S3FileSystemSchemeRegistrar}s, as well as among all {@link - * FileSystem}s registered by all {@link FileSystemRegistrar}s. - */ - Iterable fromOptions(@Nonnull PipelineOptions options); -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3ReadableSeekableByteChannel.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3ReadableSeekableByteChannel.java deleted file mode 100644 index bef1fc340888..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3ReadableSeekableByteChannel.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.s3; - -import static com.amazonaws.util.IOUtils.drainInputStream; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.AmazonClientException; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.GetObjectRequest; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import java.io.BufferedInputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.Channels; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.NonWritableChannelException; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.SeekableByteChannel; - -/** A readable S3 object, as a {@link SeekableByteChannel}. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class S3ReadableSeekableByteChannel implements SeekableByteChannel { - - private final AmazonS3 amazonS3; - private final S3ResourceId path; - private final long contentLength; - private long position = 0; - private boolean open = true; - private S3Object s3Object; - private final S3FileSystemConfiguration config; - private ReadableByteChannel s3ObjectContentChannel; - - S3ReadableSeekableByteChannel( - AmazonS3 amazonS3, S3ResourceId path, S3FileSystemConfiguration config) throws IOException { - this.amazonS3 = checkNotNull(amazonS3, "amazonS3"); - checkNotNull(path, "path"); - this.config = checkNotNull(config, "config"); - - if (path.getSize().isPresent()) { - contentLength = path.getSize().get(); - this.path = path; - - } else { - try { - contentLength = - amazonS3.getObjectMetadata(path.getBucket(), path.getKey()).getContentLength(); - } catch (AmazonClientException e) { - throw new IOException(e); - } - 
this.path = path.withSize(contentLength); - } - } - - @Override - public int read(ByteBuffer destinationBuffer) throws IOException { - if (!isOpen()) { - throw new ClosedChannelException(); - } - if (!destinationBuffer.hasRemaining()) { - return 0; - } - if (position == contentLength) { - return -1; - } - - if (s3Object == null) { - GetObjectRequest request = new GetObjectRequest(path.getBucket(), path.getKey()); - request.setSSECustomerKey(config.getSSECustomerKey()); - if (position > 0) { - request.setRange(position, contentLength); - } - try { - s3Object = amazonS3.getObject(request); - } catch (AmazonClientException e) { - throw new IOException(e); - } - s3ObjectContentChannel = - Channels.newChannel(new BufferedInputStream(s3Object.getObjectContent(), 1024 * 1024)); - } - - int totalBytesRead = 0; - int bytesRead = 0; - - do { - totalBytesRead += bytesRead; - try { - bytesRead = s3ObjectContentChannel.read(destinationBuffer); - } catch (AmazonClientException e) { - // TODO replace all catch AmazonServiceException with client exception - throw new IOException(e); - } - } while (bytesRead > 0); - - position += totalBytesRead; - return totalBytesRead; - } - - @Override - public long position() throws ClosedChannelException { - if (!isOpen()) { - throw new ClosedChannelException(); - } - return position; - } - - @Override - public SeekableByteChannel position(long newPosition) throws IOException { - if (!isOpen()) { - throw new ClosedChannelException(); - } - checkArgument(newPosition >= 0, "newPosition too low"); - checkArgument(newPosition < contentLength, "new position too high"); - - if (newPosition == position) { - return this; - } - - // The position has changed, so close and destroy the object to induce a re-creation on the next - // call to read() - if (s3Object != null) { - s3Object.close(); - s3Object = null; - } - position = newPosition; - return this; - } - - @Override - public long size() throws ClosedChannelException { - if (!isOpen()) { - throw new 
ClosedChannelException(); - } - return contentLength; - } - - @Override - public void close() throws IOException { - if (s3Object != null) { - S3ObjectInputStream s3ObjectInputStream = s3Object.getObjectContent(); - drainInputStream(s3ObjectInputStream); - s3Object.close(); - } - open = false; - } - - @Override - public boolean isOpen() { - return open; - } - - @Override - public int write(ByteBuffer src) { - throw new NonWritableChannelException(); - } - - @Override - public SeekableByteChannel truncate(long size) { - throw new NonWritableChannelException(); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3ResourceId.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3ResourceId.java deleted file mode 100644 index 2751f98d7f6e..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3ResourceId.java +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - -import java.io.ObjectStreamException; -import java.util.Date; -import java.util.Objects; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import org.apache.beam.sdk.io.fs.ResolveOptions; -import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; -import org.apache.beam.sdk.io.fs.ResourceId; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.nullness.qual.Nullable; - -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class S3ResourceId implements ResourceId { - - private static final long serialVersionUID = -8218379666994031337L; - - static final String DEFAULT_SCHEME = "s3"; - - private static final Pattern S3_URI = - Pattern.compile("(?[^:]+)://(?[^/]+)(/(?.*))?"); - - /** Matches a glob containing a wildcard, capturing the portion before the first wildcard. 
*/ - private static final Pattern GLOB_PREFIX = Pattern.compile("(?[^\\[*?]*)[\\[*?].*"); - - private final String bucket; - private final String key; - private final Long size; - private final Date lastModified; - private final String scheme; - - private S3ResourceId( - String scheme, String bucket, String key, @Nullable Long size, @Nullable Date lastModified) { - checkArgument(!Strings.isNullOrEmpty(scheme), "scheme"); - checkArgument(!Strings.isNullOrEmpty(bucket), "bucket"); - checkArgument(!bucket.contains("/"), "bucket must not contain '/': [%s]", bucket); - this.scheme = scheme; - this.bucket = bucket; - this.key = checkNotNull(key, "key"); - this.size = size; - this.lastModified = lastModified; - } - - private Object readResolve() throws ObjectStreamException { - if (scheme == null) { - return new S3ResourceId(DEFAULT_SCHEME, bucket, key, size, lastModified); - } - return this; - } - - static S3ResourceId fromComponents(String scheme, String bucket, String key) { - if (!key.startsWith("/")) { - key = "/" + key; - } - return new S3ResourceId(scheme, bucket, key, null, null); - } - - static S3ResourceId fromUri(String uri) { - Matcher m = S3_URI.matcher(uri); - checkArgument(m.matches(), "Invalid S3 URI: [%s]", uri); - String scheme = m.group("SCHEME"); - String bucket = m.group("BUCKET"); - String key = Strings.nullToEmpty(m.group("KEY")); - if (!key.startsWith("/")) { - key = "/" + key; - } - return fromComponents(scheme, bucket, key); - } - - String getBucket() { - return bucket; - } - - String getKey() { - // Skip leading slash - return key.substring(1); - } - - Optional getSize() { - return Optional.fromNullable(size); - } - - S3ResourceId withSize(long size) { - return new S3ResourceId(scheme, bucket, key, size, lastModified); - } - - Optional getLastModified() { - return Optional.fromNullable(lastModified); - } - - S3ResourceId withLastModified(Date lastModified) { - return new S3ResourceId(scheme, bucket, key, size, lastModified); - } - - @Override - 
public ResourceId resolve(String other, ResolveOptions resolveOptions) { - checkState(isDirectory(), "Expected this resource to be a directory, but was [%s]", toString()); - - if (resolveOptions == StandardResolveOptions.RESOLVE_DIRECTORY) { - if ("..".equals(other)) { - if ("/".equals(key)) { - return this; - } - int parentStopsAt = key.substring(0, key.length() - 1).lastIndexOf('/'); - return fromComponents(scheme, bucket, key.substring(0, parentStopsAt + 1)); - } - - if ("".equals(other)) { - return this; - } - - if (!other.endsWith("/")) { - other += "/"; - } - if (S3_URI.matcher(other).matches()) { - return resolveFromUri(other); - } - return fromComponents(scheme, bucket, key + other); - } - - if (resolveOptions == StandardResolveOptions.RESOLVE_FILE) { - checkArgument( - !other.endsWith("/"), "Cannot resolve a file with a directory path: [%s]", other); - checkArgument(!"..".equals(other), "Cannot resolve parent as file: [%s]", other); - if (S3_URI.matcher(other).matches()) { - return resolveFromUri(other); - } - return fromComponents(scheme, bucket, key + other); - } - - throw new UnsupportedOperationException( - String.format("Unexpected StandardResolveOptions [%s]", resolveOptions)); - } - - private S3ResourceId resolveFromUri(String uri) { - S3ResourceId id = fromUri(uri); - checkArgument( - id.getScheme().equals(scheme), - "Cannot resolve a URI as a child resource unless its scheme is [%s]; instead it was [%s]", - scheme, - id.getScheme()); - return id; - } - - @Override - public ResourceId getCurrentDirectory() { - if (isDirectory()) { - return this; - } - return fromComponents(scheme, getBucket(), key.substring(0, key.lastIndexOf('/') + 1)); - } - - @Override - public String getScheme() { - return scheme; - } - - @Override - public @Nullable String getFilename() { - if (!isDirectory()) { - return key.substring(key.lastIndexOf('/') + 1); - } - if ("/".equals(key)) { - return null; - } - String keyWithoutTrailingSlash = key.substring(0, key.length() - 
1); - return keyWithoutTrailingSlash.substring(keyWithoutTrailingSlash.lastIndexOf('/') + 1); - } - - @Override - public boolean isDirectory() { - return key.endsWith("/"); - } - - boolean isWildcard() { - return GLOB_PREFIX.matcher(getKey()).matches(); - } - - String getKeyNonWildcardPrefix() { - Matcher m = GLOB_PREFIX.matcher(getKey()); - checkArgument(m.matches(), String.format("Glob expression: [%s] is not expandable.", getKey())); - return m.group("PREFIX"); - } - - @Override - public String toString() { - return String.format("%s://%s%s", scheme, bucket, key); - } - - @Override - public boolean equals(@Nullable Object obj) { - if (!(obj instanceof S3ResourceId)) { - return false; - } - S3ResourceId o = (S3ResourceId) obj; - - return scheme.equals(o.scheme) && bucket.equals(o.bucket) && key.equals(o.key); - } - - @Override - public int hashCode() { - return Objects.hash(scheme, bucket, key); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3WritableByteChannel.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3WritableByteChannel.java deleted file mode 100644 index 3594ca5b0aaa..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3WritableByteChannel.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.AmazonClientException; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadResult; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.PartETag; -import com.amazonaws.services.s3.model.UploadPartRequest; -import com.amazonaws.services.s3.model.UploadPartResult; -import com.amazonaws.util.Base64; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.WritableByteChannel; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; - -/** A writable S3 object, as a {@link WritableByteChannel}. 
*/ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class S3WritableByteChannel implements WritableByteChannel { - private final AmazonS3 amazonS3; - private final S3FileSystemConfiguration config; - private final S3ResourceId path; - - private final String uploadId; - private final ByteBuffer uploadBuffer; - private final List eTags; - - // AWS S3 parts are 1-indexed, not zero-indexed. - private int partNumber = 1; - private boolean open = true; - private final MessageDigest md5 = md5(); - - S3WritableByteChannel( - AmazonS3 amazonS3, S3ResourceId path, String contentType, S3FileSystemConfiguration config) - throws IOException { - this.amazonS3 = checkNotNull(amazonS3, "amazonS3"); - this.config = checkNotNull(config); - this.path = checkNotNull(path, "path"); - checkArgument( - atMostOne( - config.getSSECustomerKey() != null, - config.getSSEAlgorithm() != null, - config.getSSEAwsKeyManagementParams() != null), - "Either SSECustomerKey (SSE-C) or SSEAlgorithm (SSE-S3)" - + " or SSEAwsKeyManagementParams (SSE-KMS) must not be set at the same time."); - // Amazon S3 API docs: Each part must be at least 5 MB in size, except the last part. 
- checkArgument( - config.getS3UploadBufferSizeBytes() - >= S3FileSystemConfiguration.MINIMUM_UPLOAD_BUFFER_SIZE_BYTES, - "S3UploadBufferSizeBytes must be at least %s bytes", - S3FileSystemConfiguration.MINIMUM_UPLOAD_BUFFER_SIZE_BYTES); - this.uploadBuffer = ByteBuffer.allocate(config.getS3UploadBufferSizeBytes()); - eTags = new ArrayList<>(); - - ObjectMetadata objectMetadata = new ObjectMetadata(); - objectMetadata.setContentType(contentType); - if (config.getSSEAlgorithm() != null) { - objectMetadata.setSSEAlgorithm(config.getSSEAlgorithm()); - } - InitiateMultipartUploadRequest request = - new InitiateMultipartUploadRequest(path.getBucket(), path.getKey()) - .withStorageClass(config.getS3StorageClass()) - .withObjectMetadata(objectMetadata); - request.setSSECustomerKey(config.getSSECustomerKey()); - request.setSSEAwsKeyManagementParams(config.getSSEAwsKeyManagementParams()); - request.setBucketKeyEnabled(config.getBucketKeyEnabled()); - InitiateMultipartUploadResult result; - try { - result = amazonS3.initiateMultipartUpload(request); - } catch (AmazonClientException e) { - throw new IOException(e); - } - uploadId = result.getUploadId(); - } - - private static MessageDigest md5() { - try { - return MessageDigest.getInstance("MD5"); - } catch (NoSuchAlgorithmException e) { - throw new IllegalStateException(e); - } - } - - @Override - public int write(ByteBuffer sourceBuffer) throws IOException { - if (!isOpen()) { - throw new ClosedChannelException(); - } - - int totalBytesWritten = 0; - while (sourceBuffer.hasRemaining()) { - int position = sourceBuffer.position(); - int bytesWritten = Math.min(sourceBuffer.remaining(), uploadBuffer.remaining()); - totalBytesWritten += bytesWritten; - - if (sourceBuffer.hasArray()) { - // If the underlying array is accessible, direct access is the most efficient approach. 
- int start = sourceBuffer.arrayOffset() + position; - uploadBuffer.put(sourceBuffer.array(), start, bytesWritten); - md5.update(sourceBuffer.array(), start, bytesWritten); - } else { - // Otherwise, use a readonly copy with an appropriate mark to read the current range of the - // buffer twice. - ByteBuffer copyBuffer = sourceBuffer.asReadOnlyBuffer(); - copyBuffer.mark().limit(position + bytesWritten); - uploadBuffer.put(copyBuffer); - copyBuffer.reset(); - md5.update(copyBuffer); - } - sourceBuffer.position(position + bytesWritten); // move position forward by the bytes written - - if (!uploadBuffer.hasRemaining() || sourceBuffer.hasRemaining()) { - flush(); - } - } - - return totalBytesWritten; - } - - private void flush() throws IOException { - uploadBuffer.flip(); - ByteArrayInputStream inputStream = - new ByteArrayInputStream(uploadBuffer.array(), 0, uploadBuffer.limit()); - - UploadPartRequest request = - new UploadPartRequest() - .withBucketName(path.getBucket()) - .withKey(path.getKey()) - .withUploadId(uploadId) - .withPartNumber(partNumber++) - .withPartSize(uploadBuffer.limit()) - .withMD5Digest(Base64.encodeAsString(md5.digest())) - .withInputStream(inputStream); - request.setSSECustomerKey(config.getSSECustomerKey()); - - UploadPartResult result; - try { - result = amazonS3.uploadPart(request); - } catch (AmazonClientException e) { - throw new IOException(e); - } - uploadBuffer.clear(); - md5.reset(); - eTags.add(result.getPartETag()); - } - - @Override - public boolean isOpen() { - return open; - } - - @Override - public void close() throws IOException { - open = false; - if (uploadBuffer.remaining() > 0) { - flush(); - } - CompleteMultipartUploadRequest request = - new CompleteMultipartUploadRequest() - .withBucketName(path.getBucket()) - .withKey(path.getKey()) - .withUploadId(uploadId) - .withPartETags(eTags); - try { - amazonS3.completeMultipartUpload(request); - } catch (AmazonClientException e) { - throw new IOException(e); - } - } - - 
@VisibleForTesting - static boolean atMostOne(boolean... values) { - boolean one = false; - for (boolean value : values) { - if (!one && value) { - one = true; - } else if (value) { - return false; - } - } - return true; - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/package-info.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/package-info.java deleted file mode 100644 index ebbf1d8db5a5..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/package-info.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** Defines IO connectors for Amazon Web Services S3. 
*/ -package org.apache.beam.sdk.io.aws.s3; diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java deleted file mode 100644 index 6a90c0285f20..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sns; - -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.sns.AmazonSNS; -import java.io.Serializable; - -/** - * Provides instances of AWS clients. - * - *

Please note, that any instance of {@link AwsClientsProvider} must be {@link Serializable} to - * ensure it can be sent to worker machines. - */ -public interface AwsClientsProvider extends Serializable { - - /** @deprecated SnsIO doesn't require a CloudWatch client */ - @Deprecated - @SuppressWarnings("return") - default AmazonCloudWatch getCloudWatchClient() { - return null; - } - - AmazonSNS createSnsPublisher(); -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java deleted file mode 100644 index aba3a74ccb2a..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sns; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.sns.AmazonSNS; -import com.amazonaws.services.sns.AmazonSNSClientBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** Basic implementation of {@link AwsClientsProvider} used by default in {@link SnsIO}. */ -class BasicSnsProvider implements AwsClientsProvider { - - private final String accessKey; - private final String secretKey; - private final Regions region; - private final @Nullable String serviceEndpoint; - - BasicSnsProvider( - String accessKey, String secretKey, Regions region, @Nullable String serviceEndpoint) { - checkArgument(accessKey != null, "accessKey can not be null"); - checkArgument(secretKey != null, "secretKey can not be null"); - checkArgument(region != null, "region can not be null"); - this.accessKey = accessKey; - this.secretKey = secretKey; - this.region = region; - this.serviceEndpoint = serviceEndpoint; - } - - private AWSCredentialsProvider getCredentialsProvider() { - return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)); - } - - @Override - public AmazonSNS createSnsPublisher() { - AmazonSNSClientBuilder clientBuilder = - AmazonSNSClientBuilder.standard().withCredentials(getCredentialsProvider()); - if (serviceEndpoint == null) { - clientBuilder.withRegion(region); - } else { - clientBuilder.withEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName())); - } - return clientBuilder.build(); - } -} diff --git 
a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoders.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoders.java deleted file mode 100644 index 6d546204d617..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoders.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sns; - -import com.amazonaws.ResponseMetadata; -import com.amazonaws.http.SdkHttpMetadata; -import com.amazonaws.services.sns.model.PublishResult; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.CustomCoder; -import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.io.aws.coders.AwsCoders; - -/** Coders for SNS {@link PublishResult}. 
*/ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public final class PublishResultCoders { - - private static final Coder MESSAGE_ID_CODER = StringUtf8Coder.of(); - private static final Coder RESPONSE_METADATA_CODER = - NullableCoder.of(AwsCoders.responseMetadata()); - - private PublishResultCoders() {} - - /** - * Returns a new PublishResult coder which by default serializes only the messageId. - * - * @return the PublishResult coder - */ - public static Coder defaultPublishResult() { - return new PublishResultCoder(null, null); - } - - /** - * Returns a new PublishResult coder which serializes the sdkResponseMetadata and sdkHttpMetadata, - * including the HTTP response headers. - * - * @return the PublishResult coder - */ - public static Coder fullPublishResult() { - return new PublishResultCoder( - RESPONSE_METADATA_CODER, NullableCoder.of(AwsCoders.sdkHttpMetadata())); - } - - /** - * Returns a new PublishResult coder which serializes the sdkResponseMetadata and sdkHttpMetadata, - * but does not include the HTTP response headers. 
- * - * @return the PublishResult coder - */ - public static Coder fullPublishResultWithoutHeaders() { - return new PublishResultCoder( - RESPONSE_METADATA_CODER, NullableCoder.of(AwsCoders.sdkHttpMetadataWithoutHeaders())); - } - - static class PublishResultCoder extends CustomCoder { - - private final Coder responseMetadataEncoder; - private final Coder sdkHttpMetadataCoder; - - private PublishResultCoder( - Coder responseMetadataEncoder, - Coder sdkHttpMetadataCoder) { - this.responseMetadataEncoder = responseMetadataEncoder; - this.sdkHttpMetadataCoder = sdkHttpMetadataCoder; - } - - @Override - public void encode(PublishResult value, OutputStream outStream) - throws CoderException, IOException { - MESSAGE_ID_CODER.encode(value.getMessageId(), outStream); - if (responseMetadataEncoder != null) { - responseMetadataEncoder.encode(value.getSdkResponseMetadata(), outStream); - } - if (sdkHttpMetadataCoder != null) { - sdkHttpMetadataCoder.encode(value.getSdkHttpMetadata(), outStream); - } - } - - @Override - public PublishResult decode(InputStream inStream) throws CoderException, IOException { - String messageId = MESSAGE_ID_CODER.decode(inStream); - PublishResult publishResult = new PublishResult().withMessageId(messageId); - if (responseMetadataEncoder != null) { - publishResult.setSdkResponseMetadata(responseMetadataEncoder.decode(inStream)); - } - if (sdkHttpMetadataCoder != null) { - publishResult.setSdkHttpMetadata(sdkHttpMetadataCoder.decode(inStream)); - } - return publishResult; - } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - MESSAGE_ID_CODER.verifyDeterministic(); - if (responseMetadataEncoder != null) { - responseMetadataEncoder.verifyDeterministic(); - } - if (sdkHttpMetadataCoder != null) { - sdkHttpMetadataCoder.verifyDeterministic(); - } - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java 
b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java deleted file mode 100644 index 315435861419..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sns; - -import com.amazonaws.services.sns.model.PublishResult; -import com.google.auto.service.AutoService; -import java.util.List; -import org.apache.beam.sdk.coders.CoderProvider; -import org.apache.beam.sdk.coders.CoderProviderRegistrar; -import org.apache.beam.sdk.coders.CoderProviders; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; - -/** A {@link CoderProviderRegistrar} for standard types used with {@link SnsIO}. 
*/ -@AutoService(CoderProviderRegistrar.class) -public class SnsCoderProviderRegistrar implements CoderProviderRegistrar { - @Override - public List getCoderProviders() { - return ImmutableList.of( - CoderProviders.forCoder( - TypeDescriptor.of(PublishResult.class), PublishResultCoders.defaultPublishResult())); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java deleted file mode 100644 index 291026f82f7e..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java +++ /dev/null @@ -1,420 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sns; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.amazonaws.regions.Regions; -import com.amazonaws.services.sns.AmazonSNS; -import com.amazonaws.services.sns.model.GetTopicAttributesResult; -import com.amazonaws.services.sns.model.InternalErrorException; -import com.amazonaws.services.sns.model.PublishRequest; -import com.amazonaws.services.sns.model.PublishResult; -import com.google.auto.value.AutoValue; -import java.io.IOException; -import java.io.Serializable; -import java.util.function.Predicate; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; -import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.http.HttpStatus; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * {@link PTransform}s for writing to SNS. - * - *

Writing to SNS

- * - *

Example usage: - * - *

{@code
- * PCollection data = ...;
- *
- * data.apply(SnsIO.write()
- *     .withTopicName("topicName")
- *     .withRetryConfiguration(
- *        SnsIO.RetryConfiguration.create(
- *          4, org.joda.time.Duration.standardSeconds(10)))
- *     .withAWSClientsProvider(new BasicSnsProvider(accessKey, secretKey, region))
- *     .withResultOutputTag(results));
- * }
- * - *

As a client, you need to provide at least the following things: - * - *

    - *
  • name of the SNS topic you're going to write to - *
  • retry configuration - *
  • need to specify AwsClientsProvider. You can pass on the default one BasicSnsProvider - *
  • an output tag where you can get results. Example in SnsIOTest - *
- * - *

By default, the output PublishResult contains only the messageId, all other fields are null. - * If you need the full ResponseMetadata and SdkHttpMetadata you can call {@link - * Write#withFullPublishResult}. If you need the HTTP status code but not the response headers you - * can call {@link Write#withFullPublishResultWithoutHeaders}. - * - * @deprecated Module beam-sdks-java-io-amazon-web-services is deprecated and will be - * eventually removed. Please migrate to {@link org.apache.beam.sdk.io.aws2.sns.SnsIO} in module - * beam-sdks-java-io-amazon-web-services2. - */ -@Deprecated -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public final class SnsIO { - - // Write data tp SNS - public static Write write() { - return new AutoValue_SnsIO_Write.Builder().build(); - } - - /** - * A POJO encapsulating a configuration for retry behavior when issuing requests to SNS. A retry - * will be attempted until the maxAttempts or maxDuration is exceeded, whichever comes first, for - * any of the following exceptions: - * - *

    - *
  • {@link IOException} - *
- */ - @AutoValue - @AutoValue.CopyAnnotations - public abstract static class RetryConfiguration implements Serializable { - private static final Duration DEFAULT_INITIAL_DURATION = Duration.standardSeconds(5); - - @VisibleForTesting - static final RetryPredicate DEFAULT_RETRY_PREDICATE = new DefaultRetryPredicate(); - - abstract int getMaxAttempts(); - - abstract Duration getMaxDuration(); - - abstract Duration getInitialDuration(); - - abstract RetryPredicate getRetryPredicate(); - - abstract Builder builder(); - - public static RetryConfiguration create(int maxAttempts, Duration maxDuration) { - return create(maxAttempts, maxDuration, DEFAULT_INITIAL_DURATION); - } - - @VisibleForTesting - static RetryConfiguration create( - int maxAttempts, Duration maxDuration, Duration initialDuration) { - checkArgument(maxAttempts > 0, "maxAttempts should be greater than 0"); - checkArgument( - maxDuration != null && maxDuration.isLongerThan(Duration.ZERO), - "maxDuration should be greater than 0"); - checkArgument( - initialDuration != null && initialDuration.isLongerThan(Duration.ZERO), - "initialDuration should be greater than 0"); - return new AutoValue_SnsIO_RetryConfiguration.Builder() - .setMaxAttempts(maxAttempts) - .setMaxDuration(maxDuration) - .setInitialDuration(initialDuration) - .setRetryPredicate(DEFAULT_RETRY_PREDICATE) - .build(); - } - - @AutoValue.Builder - abstract static class Builder { - abstract SnsIO.RetryConfiguration.Builder setMaxAttempts(int maxAttempts); - - abstract SnsIO.RetryConfiguration.Builder setMaxDuration(Duration maxDuration); - - abstract SnsIO.RetryConfiguration.Builder setInitialDuration(Duration initialDuration); - - abstract SnsIO.RetryConfiguration.Builder setRetryPredicate(RetryPredicate retryPredicate); - - abstract SnsIO.RetryConfiguration build(); - } - - /** - * An interface used to control if we retry the SNS Publish call when a {@link Throwable} - * occurs. 
If {@link RetryPredicate#test(Object)} returns true, {@link Write} tries to resend - * the requests to SNS if the {@link RetryConfiguration} permits it. - */ - @FunctionalInterface - interface RetryPredicate extends Predicate, Serializable {} - - private static class DefaultRetryPredicate implements RetryPredicate { - private static final ImmutableSet ELIGIBLE_CODES = - ImmutableSet.of(HttpStatus.SC_SERVICE_UNAVAILABLE); - - @Override - public boolean test(Throwable throwable) { - return (throwable instanceof IOException - || (throwable instanceof InternalErrorException) - || (throwable instanceof InternalErrorException - && ELIGIBLE_CODES.contains(((InternalErrorException) throwable).getStatusCode()))); - } - } - } - - /** Implementation of {@link #write}. */ - @AutoValue - @AutoValue.CopyAnnotations - public abstract static class Write - extends PTransform, PCollectionTuple> { - - abstract @Nullable String getTopicName(); - - abstract @Nullable AwsClientsProvider getAWSClientsProvider(); - - abstract @Nullable RetryConfiguration getRetryConfiguration(); - - abstract @Nullable TupleTag getResultOutputTag(); - - abstract @Nullable Coder getCoder(); - - abstract Builder builder(); - - @AutoValue.Builder - abstract static class Builder { - - abstract Builder setTopicName(String topicName); - - abstract Builder setAWSClientsProvider(AwsClientsProvider clientProvider); - - abstract Builder setRetryConfiguration(RetryConfiguration retryConfiguration); - - abstract Builder setResultOutputTag(TupleTag results); - - abstract Builder setCoder(Coder coder); - - abstract Write build(); - } - - /** - * Specify the SNS topic which will be used for writing, this name is mandatory. - * - * @param topicName topicName - */ - public Write withTopicName(String topicName) { - return builder().setTopicName(topicName).build(); - } - - /** - * Allows to specify custom {@link AwsClientsProvider}. 
{@link AwsClientsProvider} creates new - * {@link AmazonSNS} which is later used for writing to a SNS topic. - */ - public Write withAWSClientsProvider(AwsClientsProvider awsClientsProvider) { - return builder().setAWSClientsProvider(awsClientsProvider).build(); - } - - /** - * Specify credential details and region to be used to write to SNS. If you need more - * sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AwsClientsProvider)}. - */ - public Write withAWSClientsProvider(String awsAccessKey, String awsSecretKey, Regions region) { - return withAWSClientsProvider(awsAccessKey, awsSecretKey, region, null); - } - - /** - * Specify credential details and region to be used to write to SNS. If you need more - * sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AwsClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - */ - public Write withAWSClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region, String serviceEndpoint) { - return withAWSClientsProvider( - new BasicSnsProvider(awsAccessKey, awsSecretKey, region, serviceEndpoint)); - } - - /** - * Provides configuration to retry a failed request to publish a message to SNS. Users should - * consider that retrying might compound the underlying problem which caused the initial - * failure. Users should also be aware that once retrying is exhausted the error is surfaced to - * the runner which may then opt to retry the current partition in entirety or abort if - * the max number of retries of the runner is completed. Retrying uses an exponential backoff - * algorithm, with minimum backoff of 5 seconds and then surfacing the error once the maximum - * number of retries or maximum configuration duration is exceeded. - * - *

Example use: - * - *

{@code
-     * SnsIO.write()
-     *   .withRetryConfiguration(SnsIO.RetryConfiguration.create(5, Duration.standardMinutes(1))
-     *   ...
-     * }
- * - * @param retryConfiguration the rules which govern the retry behavior - * @return the {@link Write} with retrying configured - */ - public Write withRetryConfiguration(RetryConfiguration retryConfiguration) { - checkArgument(retryConfiguration != null, "retryConfiguration is required"); - return builder().setRetryConfiguration(retryConfiguration).build(); - } - - /** Tuple tag to store results. Mandatory field. */ - public Write withResultOutputTag(TupleTag results) { - return builder().setResultOutputTag(results).build(); - } - - /** - * Encode the full {@code PublishResult} object, including sdkResponseMetadata and - * sdkHttpMetadata with the HTTP response headers. - */ - public Write withFullPublishResult() { - return withCoder(PublishResultCoders.fullPublishResult()); - } - - /** - * Encode the full {@code PublishResult} object, including sdkResponseMetadata and - * sdkHttpMetadata but excluding the HTTP response headers. - */ - public Write withFullPublishResultWithoutHeaders() { - return withCoder(PublishResultCoders.fullPublishResultWithoutHeaders()); - } - - /** Encode the {@code PublishResult} with the given coder. */ - public Write withCoder(Coder coder) { - return builder().setCoder(coder).build(); - } - - @Override - public PCollectionTuple expand(PCollection input) { - LoggerFactory.getLogger(SnsIO.class) - .warn( - "You are using a deprecated IO for Sns. Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - checkArgument(getTopicName() != null, "withTopicName() is required"); - PCollectionTuple result = - input.apply( - ParDo.of(new SnsWriterFn(this)) - .withOutputTags(getResultOutputTag(), TupleTagList.empty())); - if (getCoder() != null) { - result.get(getResultOutputTag()).setCoder(getCoder()); - } - return result; - } - - static class SnsWriterFn extends DoFn { - @VisibleForTesting - static final String RETRY_ATTEMPT_LOG = "Error writing to SNS. 
Retry attempt[{}]"; - - private transient FluentBackoff retryBackoff; // defaults to no retries - private static final Logger LOG = LoggerFactory.getLogger(SnsWriterFn.class); - private static final Counter SNS_WRITE_FAILURES = - Metrics.counter(SnsWriterFn.class, "SNS_Write_Failures"); - - private final SnsIO.Write spec; - private transient AmazonSNS producer; - - SnsWriterFn(SnsIO.Write spec) { - this.spec = spec; - } - - @Setup - public void setup() throws Exception { - // Initialize SnsPublisher - producer = spec.getAWSClientsProvider().createSnsPublisher(); - checkArgument( - topicExists(producer, spec.getTopicName()), - "Topic %s does not exist", - spec.getTopicName()); - - retryBackoff = FluentBackoff.DEFAULT.withMaxRetries(0); // default to no retrying - if (spec.getRetryConfiguration() != null) { - retryBackoff = - retryBackoff - .withMaxRetries(spec.getRetryConfiguration().getMaxAttempts() - 1) - .withInitialBackoff(spec.getRetryConfiguration().getInitialDuration()) - .withMaxCumulativeBackoff(spec.getRetryConfiguration().getMaxDuration()); - } - } - - @ProcessElement - public void processElement(ProcessContext context) throws Exception { - PublishRequest request = context.element(); - Sleeper sleeper = Sleeper.DEFAULT; - BackOff backoff = retryBackoff.backoff(); - int attempt = 0; - while (true) { - attempt++; - try { - PublishResult pr = producer.publish(request); - context.output(pr); - break; - } catch (Exception ex) { - // Fail right away if there is no retry configuration - if (spec.getRetryConfiguration() == null - || !spec.getRetryConfiguration().getRetryPredicate().test(ex)) { - SNS_WRITE_FAILURES.inc(); - LOG.info("Unable to publish message {}.", request.getMessage(), ex); - throw new IOException("Error writing to SNS (no attempt made to retry)", ex); - } - - if (!BackOffUtils.next(sleeper, backoff)) { - throw new IOException( - String.format( - "Error writing to SNS after %d attempt(s). 
No more attempts allowed", - attempt), - ex); - } else { - // Note: this used in test cases to verify behavior - LOG.warn(RETRY_ATTEMPT_LOG, attempt, ex); - } - } - } - } - - @Teardown - public void tearDown() { - if (producer != null) { - producer.shutdown(); - producer = null; - } - } - - @SuppressWarnings({"checkstyle:illegalCatch"}) - private static boolean topicExists(AmazonSNS client, String topicName) { - try { - GetTopicAttributesResult topicAttributesResult = client.getTopicAttributes(topicName); - return topicAttributesResult != null - && topicAttributesResult.getSdkHttpMetadata().getHttpStatusCode() == 200; - } catch (Exception e) { - LOG.warn("Error checking whether topic {} exists.", topicName, e); - throw e; - } - } - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/package-info.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/package-info.java deleted file mode 100644 index a1895cf4ce6d..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/package-info.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -/** Defines IO connectors for Amazon Web Services SNS. */ -package org.apache.beam.sdk.io.aws.sns; diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsCheckpointMark.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsCheckpointMark.java deleted file mode 100644 index b3e23bff5554..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsCheckpointMark.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sqs; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - -import java.io.IOException; -import java.io.Serializable; -import java.util.List; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Objects; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.checkerframework.checker.nullness.qual.Nullable; - -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class SqsCheckpointMark implements UnboundedSource.CheckpointMark, Serializable { - - /** - * If the checkpoint is for persisting: the reader who's snapshotted state we are persisting. If - * the checkpoint is for restoring: {@literal null}. Not persisted in durable checkpoint. CAUTION: - * Between a checkpoint being taken and {@link #finalizeCheckpoint()} being called the 'true' - * active reader may have changed. - */ - private transient @Nullable SqsUnboundedReader reader; - - /** - * If the checkpoint is for persisting: The ids of messages which have been passed downstream - * since the last checkpoint. If the checkpoint is for restoring: {@literal null}. Not persisted - * in durable checkpoint. - */ - private @Nullable List safeToDeleteIds; - - /** - * If the checkpoint is for persisting: The receipt handles of messages which have been received - * from SQS but not yet passed downstream at the time of the snapshot. If the checkpoint is for - * restoring: Same, but recovered from durable storage. 
- */ - @VisibleForTesting final List notYetReadReceipts; - - public SqsCheckpointMark( - SqsUnboundedReader reader, List messagesToDelete, List notYetReadReceipts) { - this.reader = reader; - this.safeToDeleteIds = ImmutableList.copyOf(messagesToDelete); - this.notYetReadReceipts = ImmutableList.copyOf(notYetReadReceipts); - } - - @Override - public void finalizeCheckpoint() throws IOException { - checkState(reader != null && safeToDeleteIds != null, "Cannot finalize a restored checkpoint"); - // Even if the 'true' active reader has changed since the checkpoint was taken we are - // fine: - // - The underlying SQS topic will not have changed, so the following deletes will still - // go to the right place. - // - We'll delete the ACK ids from the readers in-flight state, but that only affect - // flow control and stats, neither of which are relevant anymore. - try { - reader.delete(safeToDeleteIds); - } finally { - int remainingInFlight = reader.numInFlightCheckpoints.decrementAndGet(); - checkState(remainingInFlight >= 0, "Miscounted in-flight checkpoints"); - reader.maybeCloseClient(); - reader = null; - safeToDeleteIds = null; - } - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - SqsCheckpointMark that = (SqsCheckpointMark) o; - return Objects.equal(safeToDeleteIds, that.safeToDeleteIds); - } - - @Override - public int hashCode() { - return Objects.hashCode(safeToDeleteIds); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsConfiguration.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsConfiguration.java deleted file mode 100644 index 3c798112325e..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsConfiguration.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) 
under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.IOException; -import java.io.Serializable; -import org.apache.beam.sdk.io.aws.options.AwsModule; -import org.apache.beam.sdk.io.aws.options.AwsOptions; - -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class SqsConfiguration implements Serializable { - - private String awsRegion; - private String awsCredentialsProviderString; - private String awsClientConfigurationString; - - public SqsConfiguration(AwsOptions awsOptions) { - ObjectMapper om = new ObjectMapper(); - om.registerModule(new AwsModule()); - try { - this.awsCredentialsProviderString = - om.writeValueAsString(awsOptions.getAwsCredentialsProvider()); - } catch (JsonProcessingException e) { - this.awsCredentialsProviderString = null; - } - - try { - this.awsClientConfigurationString = - om.writeValueAsString(awsOptions.getClientConfiguration()); - } catch (JsonProcessingException e) { - this.awsClientConfigurationString = null; - } - - this.awsRegion = 
awsOptions.getAwsRegion(); - } - - public AWSCredentialsProvider getAwsCredentialsProvider() { - ObjectMapper om = new ObjectMapper(); - om.registerModule(new AwsModule()); - try { - return om.readValue(awsCredentialsProviderString, AWSCredentialsProvider.class); - } catch (IOException e) { - return null; - } - } - - public ClientConfiguration getClientConfiguration() { - ObjectMapper om = new ObjectMapper(); - om.registerModule(new AwsModule()); - try { - return om.readValue(awsClientConfigurationString, ClientConfiguration.class); - } catch (IOException e) { - return null; - } - } - - public String getAwsRegion() { - return awsRegion; - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsIO.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsIO.java deleted file mode 100644 index 26ca03c95c33..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsIO.java +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sqs; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.AmazonSQSClientBuilder; -import com.amazonaws.services.sqs.model.Message; -import com.amazonaws.services.sqs.model.SendMessageRequest; -import com.google.auto.value.AutoValue; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.io.aws.options.AwsOptions; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PBegin; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PDone; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; -import org.slf4j.LoggerFactory; - -/** - * An unbounded source for Amazon Simple Queue Service (SQS). - * - *

Reading from an SQS queue

- * - *

The {@link SqsIO} {@link Read} returns an unbounded {@link PCollection} of {@link - * com.amazonaws.services.sqs.model.Message} containing the received messages. Note: This source - * does not currently advance the watermark when no new messages are received. - * - *

To configure an SQS source, you have to provide the queueUrl to connect to. The following - * example illustrates how to configure the source: - * - *

{@code
- * pipeline.apply(SqsIO.read().withQueueUrl(queueUrl))
- * }
- * - *

Writing to an SQS queue

- * - *

The following example illustrates how to use the sink: - * - *

{@code
- * pipeline
- *   .apply(...) // returns PCollection
- *   .apply(SqsIO.write())
- * }
- * - *

Additional Configuration

- * - *

Additional configuration can be provided via {@link AwsOptions} from command line args or in - * code. For example, if you wanted to provide a secret access key via code: - * - *

{@code
- * PipelineOptions pipelineOptions = PipelineOptionsFactory.fromArgs(args).withValidation().create();
- * AwsOptions awsOptions = pipelineOptions.as(AwsOptions.class);
- * BasicAWSCredentials awsCreds = new BasicAWSCredentials("accesskey", "secretkey");
- * awsOptions.setAwsCredentialsProvider(new AWSStaticCredentialsProvider(awsCreds));
- * Pipeline pipeline = Pipeline.create(options);
- * }
- * - *

For more information on the available options see {@link AwsOptions}. - * - * @deprecated Module beam-sdks-java-io-amazon-web-services is deprecated and will be - * eventually removed. Please migrate to {@link org.apache.beam.sdk.io.aws2.sqs.SqsIO} in module - * beam-sdks-java-io-amazon-web-services2. - */ -@Deprecated -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class SqsIO { - - public static Read read() { - return new AutoValue_SqsIO_Read.Builder() - .setCoder(SqsMessageCoder.of()) - .setMaxNumRecords(Long.MAX_VALUE) - .build(); - } - - public static Write write() { - return new AutoValue_SqsIO_Write.Builder().build(); - } - - private SqsIO() {} - - /** - * A {@link PTransform} to read/receive messages from SQS. See {@link SqsIO} for more information - * on usage and configuration. - */ - @AutoValue - public abstract static class Read extends PTransform> { - - abstract Coder coder(); - - abstract @Nullable String queueUrl(); - - abstract long maxNumRecords(); - - abstract @Nullable Duration maxReadTime(); - - abstract Builder toBuilder(); - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setCoder(Coder coder); - - abstract Builder setQueueUrl(String queueUrl); - - abstract Builder setMaxNumRecords(long maxNumRecords); - - abstract Builder setMaxReadTime(Duration maxReadTime); - - abstract Read build(); - } - - /** - * Optionally set a custom {@link Message} output coder if you need to access further (message) - * attributes. - * - *

The default {@link SqsMessageCoder} only supports `SentTimestamp` and - * `requestTimeMsSinceEpoch`. - */ - public Read withCoder(Coder coder) { - return toBuilder().setCoder(coder).build(); - } - - /** - * Define the max number of records received by the {@link Read}. When the max number of records - * is lower than {@code Long.MAX_VALUE}, the {@link Read} will provide a bounded {@link - * PCollection}. - */ - public Read withMaxNumRecords(long maxNumRecords) { - return toBuilder().setMaxNumRecords(maxNumRecords).build(); - } - - /** - * Define the max read time (duration) while the {@link Read} will receive messages. When this - * max read time is not null, the {@link Read} will provide a bounded {@link PCollection}. - */ - public Read withMaxReadTime(Duration maxReadTime) { - return toBuilder().setMaxReadTime(maxReadTime).build(); - } - - /** Define the queueUrl used by the {@link Read} to receive messages from SQS. */ - public Read withQueueUrl(String queueUrl) { - checkArgument(queueUrl != null, "queueUrl can not be null"); - checkArgument(!queueUrl.isEmpty(), "queueUrl can not be empty"); - return toBuilder().setQueueUrl(queueUrl).build(); - } - - @Override - public PCollection expand(PBegin input) { - LoggerFactory.getLogger(SqsIO.class) - .warn( - "You are using a deprecated IO for Sqs. Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - org.apache.beam.sdk.io.Read.Unbounded unbounded = - org.apache.beam.sdk.io.Read.from( - new SqsUnboundedSource( - this, - new SqsConfiguration(input.getPipeline().getOptions().as(AwsOptions.class)), - coder())); - - PTransform> transform = unbounded; - - if (maxNumRecords() < Long.MAX_VALUE || maxReadTime() != null) { - transform = unbounded.withMaxReadTime(maxReadTime()).withMaxNumRecords(maxNumRecords()); - } - - return input.getPipeline().apply(transform); - } - } - - /** - * A {@link PTransform} to send messages to SQS. 
See {@link SqsIO} for more information on usage - * and configuration. - */ - @AutoValue - public abstract static class Write extends PTransform, PDone> { - abstract Builder toBuilder(); - - @AutoValue.Builder - abstract static class Builder { - abstract Write build(); - } - - @Override - public PDone expand(PCollection input) { - LoggerFactory.getLogger(SqsIO.class) - .warn( - "You are using a deprecated IO for Sqs. Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - input.apply( - ParDo.of( - new SqsWriteFn( - new SqsConfiguration(input.getPipeline().getOptions().as(AwsOptions.class))))); - return PDone.in(input.getPipeline()); - } - } - - private static class SqsWriteFn extends DoFn { - private final SqsConfiguration sqsConfiguration; - private transient AmazonSQS sqs; - - SqsWriteFn(SqsConfiguration sqsConfiguration) { - this.sqsConfiguration = sqsConfiguration; - } - - @Setup - public void setup() { - sqs = - AmazonSQSClientBuilder.standard() - .withClientConfiguration(sqsConfiguration.getClientConfiguration()) - .withCredentials(sqsConfiguration.getAwsCredentialsProvider()) - .withRegion(sqsConfiguration.getAwsRegion()) - .build(); - } - - @ProcessElement - public void processElement(ProcessContext processContext) throws Exception { - sqs.sendMessage(processContext.element()); - } - - @Teardown - public void teardown() throws Exception { - if (sqs != null) { - sqs.shutdown(); - } - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsMessageCoder.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsMessageCoder.java deleted file mode 100644 index 792642c17609..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsMessageCoder.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import static com.amazonaws.services.sqs.model.MessageSystemAttributeName.SentTimestamp; -import static org.apache.beam.sdk.io.aws.sqs.SqsUnboundedReader.REQUEST_TIME; - -import com.amazonaws.services.sqs.model.Message; -import com.amazonaws.services.sqs.model.MessageAttributeValue; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.checkerframework.checker.nullness.qual.NonNull; - -/** - * Deterministic coder for an AWS Sdk SQS message. - * - *

This encoder only keeps the `SentTimestamp` attribute as well as the `requestTimeMsSinceEpoch` - * message attribute, other attributes are dropped. You may provide your own coder in case you need - * to access further attributes. - */ -class SqsMessageCoder extends AtomicCoder { - private static final Coder STRING_CODER = StringUtf8Coder.of(); - private static final NullableCoder OPT_STRING_CODER = - NullableCoder.of(StringUtf8Coder.of()); - - private static final Coder INSTANCE = new SqsMessageCoder(); - - static Coder of() { - return INSTANCE; - } - - private SqsMessageCoder() {} - - @Override - public void encode(Message value, OutputStream out) throws IOException { - STRING_CODER.encode(value.getMessageId(), out); - STRING_CODER.encode(value.getReceiptHandle(), out); - OPT_STRING_CODER.encode(value.getBody(), out); - OPT_STRING_CODER.encode(value.getAttributes().get(SentTimestamp.toString()), out); - MessageAttributeValue reqTime = value.getMessageAttributes().get(REQUEST_TIME); - OPT_STRING_CODER.encode(reqTime != null ? 
reqTime.getStringValue() : null, out); - } - - @Override - public Message decode(InputStream in) throws IOException { - Message msg = new Message(); - msg.setMessageId(STRING_CODER.decode(in)); - msg.setReceiptHandle(STRING_CODER.decode(in)); - - // SQS library not annotated, but this coder assumes null is allowed (documentation does not - // specify) - @SuppressWarnings("nullness") - @NonNull - String body = OPT_STRING_CODER.decode(in); - msg.setBody(body); - - String sentAt = OPT_STRING_CODER.decode(in); - if (sentAt != null) { - msg.addAttributesEntry(SentTimestamp.toString(), sentAt); - } - - String reqTime = OPT_STRING_CODER.decode(in); - if (reqTime != null) { - msg.addMessageAttributesEntry( - REQUEST_TIME, new MessageAttributeValue().withStringValue(reqTime)); - } - return msg; - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedReader.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedReader.java deleted file mode 100644 index 1fd5e38f5464..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedReader.java +++ /dev/null @@ -1,944 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.stream.Collectors.groupingBy; -import static java.util.stream.Collectors.toMap; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.AmazonSQSClientBuilder; -import com.amazonaws.services.sqs.model.BatchResultErrorEntry; -import com.amazonaws.services.sqs.model.ChangeMessageVisibilityBatchRequestEntry; -import com.amazonaws.services.sqs.model.ChangeMessageVisibilityBatchResult; -import com.amazonaws.services.sqs.model.DeleteMessageBatchRequestEntry; -import com.amazonaws.services.sqs.model.DeleteMessageBatchResult; -import com.amazonaws.services.sqs.model.GetQueueAttributesRequest; -import com.amazonaws.services.sqs.model.Message; -import com.amazonaws.services.sqs.model.MessageAttributeValue; -import com.amazonaws.services.sqs.model.MessageSystemAttributeName; -import com.amazonaws.services.sqs.model.QueueAttributeName; -import com.amazonaws.services.sqs.model.ReceiveMessageRequest; -import com.amazonaws.services.sqs.model.ReceiveMessageResult; -import java.io.IOException; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Objects; -import java.util.Queue; -import java.util.Set; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.apache.beam.sdk.io.UnboundedSource; -import 
org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Max; -import org.apache.beam.sdk.transforms.Min; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.BucketingFunction; -import org.apache.beam.sdk.util.MovingFunction; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class SqsUnboundedReader extends UnboundedSource.UnboundedReader { - private static final Logger LOG = LoggerFactory.getLogger(SqsUnboundedReader.class); - - /** Request time attribute in {@link Message#getMessageAttributes()}. */ - static final String REQUEST_TIME = "requestTimeMsSinceEpoch"; - - /** Maximum number of messages to pull from SQS per request. */ - public static final int MAX_NUMBER_OF_MESSAGES = 10; - - /** Maximum times to retry batch SQS operations upon partial success. */ - private static final int BATCH_OPERATION_MAX_RETIRES = 5; - - /** Timeout for round trip from receiving a message to finally deleting it from SQS. */ - private static final Duration PROCESSING_TIMEOUT = Duration.standardMinutes(2); - - /** - * Percentage of visibility timeout by which to extend visibility timeout when they are near - * timeout. - */ - private static final int VISIBILITY_EXTENSION_PCT = 50; - - /** - * Percentage of ack timeout we should use as a safety margin. 
We'll try to extend visibility - * timeout by this margin before the visibility timeout actually expires. - */ - private static final int VISIBILITY_SAFETY_PCT = 20; - - /** - * For stats only: How close we can get to an visibility deadline before we risk it being already - * considered passed by SQS. - */ - private static final Duration VISIBILITY_TOO_LATE = Duration.standardSeconds(2); - - /** Maximum number of message ids per delete or visibility extension call. */ - private static final int DELETE_BATCH_SIZE = 10; - - /** Maximum number of messages in flight. */ - private static final int MAX_IN_FLIGHT = 20000; - - /** Maximum number of recent messages for calculating average message size. */ - private static final int MAX_AVG_BYTE_MESSAGES = 20; - - /** Period of samples to determine watermark and other stats. */ - private static final Duration SAMPLE_PERIOD = Duration.standardMinutes(1); - - /** Period of updates to determine watermark and other stats. */ - private static final Duration SAMPLE_UPDATE = Duration.standardSeconds(5); - - /** Period for logging stats. */ - private static final Duration LOG_PERIOD = Duration.standardSeconds(30); - - /** Minimum number of unread messages required before considering updating watermark. */ - private static final int MIN_WATERMARK_MESSAGES = 10; - - /** - * Minimum number of SAMPLE_UPDATE periods over which unread messages should be spread before - * considering updating watermark. - */ - private static final int MIN_WATERMARK_SPREAD = 2; - - private static final Combine.BinaryCombineLongFn MIN = Min.ofLongs(); - - private static final Combine.BinaryCombineLongFn MAX = Max.ofLongs(); - - private static final Combine.BinaryCombineLongFn SUM = Sum.ofLongs(); - - /** For access to topic and SQS client. */ - private final SqsUnboundedSource source; - - /** - * The closed state of this {@link SqsUnboundedReader}. 
If true, the reader has not yet been - * closed, and it will have a non-null value within {@link #SqsUnboundedReader}. - */ - private AtomicBoolean active = new AtomicBoolean(true); - - /** SQS client of this reader instance. */ - private AmazonSQS sqsClient = null; - - /** The current message, or {@literal null} if none. */ - private Message current; - - /** - * Messages we have received from SQS and not yet delivered downstream. We preserve their order. - */ - final Queue messagesNotYetRead; - - /** Message ids of messages we have delivered downstream but not yet deleted. */ - private Set safeToDeleteIds; - - /** - * Visibility timeout, in ms, as set on subscription when we first start reading. Not updated - * thereafter. -1 if not yet determined. - */ - private long visibilityTimeoutMs; - - /** Byte size of undecoded elements in {@link #messagesNotYetRead}. */ - private long notYetReadBytes; - - /** Byte size of recent messages. */ - private EvictingQueue recentMessageBytes; - - /** - * Bucketed map from received time (as system time, ms since epoch) to message timestamps (mssince - * epoch) of all received but not-yet read messages. Used to estimate watermark. - */ - private BucketingFunction minUnreadTimestampMsSinceEpoch; - - /** - * Minimum of timestamps (ms since epoch) of all recently read messages. Used to estimate - * watermark. - */ - private MovingFunction minReadTimestampMsSinceEpoch; - - /** Number of recent empty receives. */ - private MovingFunction numEmptyReceives; - - private static class InFlightState { - /** Receipt handle of message. */ - String receiptHandle; - - /** When request which yielded message was issued. */ - long requestTimeMsSinceEpoch; - - /** - * When SQS will consider this message's visibility timeout to timeout and thus it needs to be - * extended. 
- */ - long visibilityDeadlineMsSinceEpoch; - - public InFlightState( - String receiptHandle, long requestTimeMsSinceEpoch, long visibilityDeadlineMsSinceEpoch) { - this.receiptHandle = receiptHandle; - this.requestTimeMsSinceEpoch = requestTimeMsSinceEpoch; - this.visibilityDeadlineMsSinceEpoch = visibilityDeadlineMsSinceEpoch; - } - } - - /** - * Map from message ids of messages we have received from SQS but not yet deleted to their in - * flight state. Ordered from earliest to latest visibility deadline. - */ - private final LinkedHashMap inFlight; - - /** - * Batches of successfully deleted message ids which need to be pruned from the above. CAUTION: - * Accessed by both reader and checkpointing threads. - */ - private final Queue> deletedIds; - - /** - * System time (ms since epoch) we last received a message from SQS, or -1 if not yet received any - * messages. - */ - private long lastReceivedMsSinceEpoch; - - /** The last reported watermark (ms since epoch), or beginning of time if none yet reported. */ - private long lastWatermarkMsSinceEpoch; - - /** Stats only: System time (ms since epoch) we last logs stats, or -1 if never. */ - private long lastLogTimestampMsSinceEpoch; - - /** Stats only: Total number of messages received. */ - private long numReceived; - - /** Stats only: Number of messages which have recently been received. */ - private MovingFunction numReceivedRecently; - - /** Stats only: Number of messages which have recently had their deadline extended. */ - private MovingFunction numExtendedDeadlines; - - /** - * Stats only: Number of messages which have recently had their deadline extended even though it - * may be too late to do so. - */ - private MovingFunction numLateDeadlines; - - /** Stats only: Number of messages which have recently been deleted. */ - private MovingFunction numDeleted; - - /** - * Stats only: Number of messages which have recently expired (visibility timeout were extended - * for too long). 
- */ - private MovingFunction numExpired; - - /** Stats only: Number of messages which have recently been returned to visible on SQS. */ - private MovingFunction numReleased; - - /** Stats only: Number of message bytes which have recently been read by downstream consumer. */ - private MovingFunction numReadBytes; - - /** - * Stats only: Minimum of timestamp (ms since epoch) of all recently received messages. Used to - * estimate timestamp skew. Does not contribute to watermark estimator. - */ - private MovingFunction minReceivedTimestampMsSinceEpoch; - - /** - * Stats only: Maximum of timestamp (ms since epoch) of all recently received messages. Used to - * estimate timestamp skew. - */ - private MovingFunction maxReceivedTimestampMsSinceEpoch; - - /** Stats only: Minimum of recent estimated watermarks (ms since epoch). */ - private MovingFunction minWatermarkMsSinceEpoch; - - /** Stats ony: Maximum of recent estimated watermarks (ms since epoch). */ - private MovingFunction maxWatermarkMsSinceEpoch; - - /** - * Stats only: Number of messages with timestamps strictly behind the estimated watermark at the - * time they are received. These may be considered 'late' by downstream computations. - */ - private MovingFunction numLateMessages; - - /** - * Stats only: Current number of checkpoints in flight. CAUTION: Accessed by both checkpointing - * and reader threads. - */ - AtomicInteger numInFlightCheckpoints; - - /** Stats only: Maximum number of checkpoints in flight at any time. 
*/ - private int maxInFlightCheckpoints; - - private static MovingFunction newFun(Combine.BinaryCombineLongFn function) { - return new MovingFunction( - SAMPLE_PERIOD.getMillis(), - SAMPLE_UPDATE.getMillis(), - MIN_WATERMARK_SPREAD, - MIN_WATERMARK_MESSAGES, - function); - } - - public SqsUnboundedReader(SqsUnboundedSource source, SqsCheckpointMark sqsCheckpointMark) - throws IOException { - this.source = source; - - messagesNotYetRead = new ArrayDeque<>(); - safeToDeleteIds = new HashSet<>(); - inFlight = new LinkedHashMap<>(); - deletedIds = new ConcurrentLinkedQueue<>(); - visibilityTimeoutMs = -1; - notYetReadBytes = 0; - recentMessageBytes = EvictingQueue.create(MAX_AVG_BYTE_MESSAGES); - minUnreadTimestampMsSinceEpoch = - new BucketingFunction( - SAMPLE_UPDATE.getMillis(), MIN_WATERMARK_SPREAD, MIN_WATERMARK_MESSAGES, MIN); - minReadTimestampMsSinceEpoch = newFun(MIN); - numEmptyReceives = newFun(SUM); - lastReceivedMsSinceEpoch = -1; - lastWatermarkMsSinceEpoch = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis(); - current = null; - lastLogTimestampMsSinceEpoch = -1; - numReceived = 0L; - numReceivedRecently = newFun(SUM); - numExtendedDeadlines = newFun(SUM); - numLateDeadlines = newFun(SUM); - numDeleted = newFun(SUM); - numExpired = newFun(SUM); - numReleased = newFun(SUM); - numReadBytes = newFun(SUM); - minReceivedTimestampMsSinceEpoch = newFun(MIN); - maxReceivedTimestampMsSinceEpoch = newFun(MAX); - minWatermarkMsSinceEpoch = newFun(MIN); - maxWatermarkMsSinceEpoch = newFun(MAX); - numLateMessages = newFun(SUM); - numInFlightCheckpoints = new AtomicInteger(); - maxInFlightCheckpoints = 0; - - if (sqsCheckpointMark != null) { - long nowMsSinceEpoch = now(); - initClient(); - extendBatch(nowMsSinceEpoch, sqsCheckpointMark.notYetReadReceipts, 0); - numReleased.add(nowMsSinceEpoch, sqsCheckpointMark.notYetReadReceipts.size()); - } - } - - @Override - public Instant getWatermark() { - - // NOTE: We'll allow the watermark to go backwards. 
The underlying runner is responsible - // for aggregating all reported watermarks and ensuring the aggregate is latched. - // If we attempt to latch locally then it is possible a temporary starvation of one reader - // could cause its estimated watermark to fast forward to current system time. Then when - // the reader resumes its watermark would be unable to resume tracking. - // By letting the underlying runner latch we avoid any problems due to localized starvation. - long nowMsSinceEpoch = now(); - long readMin = minReadTimestampMsSinceEpoch.get(nowMsSinceEpoch); - long unreadMin = minUnreadTimestampMsSinceEpoch.get(); - if (readMin == Long.MAX_VALUE - && unreadMin == Long.MAX_VALUE - && numEmptyReceives.get(nowMsSinceEpoch) > 0 - && nowMsSinceEpoch > lastReceivedMsSinceEpoch + SAMPLE_PERIOD.getMillis()) { - // We don't currently have any unread messages pending, we have not had any messages - // read for a while, and we have not received any new messages from SQS for a while. - // Advance watermark to current time. - // TODO: Estimate a timestamp lag. - lastWatermarkMsSinceEpoch = nowMsSinceEpoch; - } else if (minReadTimestampMsSinceEpoch.isSignificant() - || minUnreadTimestampMsSinceEpoch.isSignificant()) { - // Take minimum of the timestamps in all unread messages and recently read messages. - lastWatermarkMsSinceEpoch = Math.min(readMin, unreadMin); - } - // else: We're not confident enough to estimate a new watermark. Stick with the old one. 
- minWatermarkMsSinceEpoch.add(nowMsSinceEpoch, lastWatermarkMsSinceEpoch); - maxWatermarkMsSinceEpoch.add(nowMsSinceEpoch, lastWatermarkMsSinceEpoch); - return new Instant(lastWatermarkMsSinceEpoch); - } - - @Override - public Message getCurrent() throws NoSuchElementException { - if (current == null) { - throw new NoSuchElementException(); - } - return current; - } - - @Override - public Instant getCurrentTimestamp() throws NoSuchElementException { - if (current == null) { - throw new NoSuchElementException(); - } - - return getTimestamp(current); - } - - @Override - public byte[] getCurrentRecordId() throws NoSuchElementException { - if (current == null) { - throw new NoSuchElementException(); - } - return current.getMessageId().getBytes(UTF_8); - } - - @Override - public CheckpointMark getCheckpointMark() { - int cur = numInFlightCheckpoints.incrementAndGet(); - maxInFlightCheckpoints = Math.max(maxInFlightCheckpoints, cur); - List snapshotSafeToDeleteIds = Lists.newArrayList(safeToDeleteIds); - List snapshotNotYetReadReceipts = new ArrayList<>(messagesNotYetRead.size()); - for (Message message : messagesNotYetRead) { - snapshotNotYetReadReceipts.add(message.getReceiptHandle()); - } - return new SqsCheckpointMark(this, snapshotSafeToDeleteIds, snapshotNotYetReadReceipts); - } - - @Override - public SqsUnboundedSource getCurrentSource() { - return source; - } - - @Override - public long getTotalBacklogBytes() { - long avgBytes = avgMessageBytes(); - List requestAttributes = - Collections.singletonList(QueueAttributeName.ApproximateNumberOfMessages.toString()); - Map queueAttributes = - sqsClient - .getQueueAttributes(source.getRead().queueUrl(), requestAttributes) - .getAttributes(); - long numMessages = - Long.parseLong( - queueAttributes.get(QueueAttributeName.ApproximateNumberOfMessages.toString())); - - // No messages consumed for estimating average message size - if (avgBytes == -1 && numMessages > 0) { - return BACKLOG_UNKNOWN; - } else { - return 
numMessages * avgBytes; - } - } - - @Override - public boolean start() throws IOException { - initClient(); - visibilityTimeoutMs = - Integer.parseInt( - sqsClient - .getQueueAttributes( - new GetQueueAttributesRequest(source.getRead().queueUrl()) - .withAttributeNames("VisibilityTimeout")) - .getAttributes() - .get("VisibilityTimeout")) - * 1000L; - return advance(); - } - - private void initClient() { - if (sqsClient == null) { - sqsClient = - AmazonSQSClientBuilder.standard() - .withClientConfiguration(source.getSqsConfiguration().getClientConfiguration()) - .withCredentials(source.getSqsConfiguration().getAwsCredentialsProvider()) - .withRegion(source.getSqsConfiguration().getAwsRegion()) - .build(); - } - } - - @Override - public boolean advance() throws IOException { - // Emit stats. - stats(); - - if (current != null) { - // Current is consumed. It can no longer contribute to holding back the watermark. - minUnreadTimestampMsSinceEpoch.remove(getRequestTimeMsSinceEpoch(current)); - current = null; - } - - // Retire state associated with deleted messages. - retire(); - - // Extend all pressing deadlines. - // Will BLOCK until done. - // If the system is pulling messages only to let them sit in a downstream queue then - // this will have the effect of slowing down the pull rate. - // However, if the system is genuinely taking longer to process each message then - // the work to extend visibility timeout would be better done in the background. - extend(); - - if (messagesNotYetRead.isEmpty()) { - // Pull another batch. - // Will BLOCK until fetch returns, but will not block until a message is available. - pull(); - } - - // Take one message from queue. - current = messagesNotYetRead.poll(); - if (current == null) { - // Try again later. 
- return false; - } - notYetReadBytes -= current.getBody().getBytes(UTF_8).length; - checkState(notYetReadBytes >= 0); - long nowMsSinceEpoch = now(); - numReadBytes.add(nowMsSinceEpoch, current.getBody().getBytes(UTF_8).length); - recentMessageBytes.add(current.getBody().getBytes(UTF_8).length); - minReadTimestampMsSinceEpoch.add(nowMsSinceEpoch, getCurrentTimestamp().getMillis()); - if (getCurrentTimestamp().getMillis() < lastWatermarkMsSinceEpoch) { - numLateMessages.add(nowMsSinceEpoch, 1L); - } - - // Current message can be considered 'read' and will be persisted by the next - // checkpoint. So it is now safe to delete from SQS. - safeToDeleteIds.add(current.getMessageId()); - - return true; - } - - /** - * {@inheritDoc}. - * - *

Marks this {@link SqsUnboundedReader} as no longer active. The {@link AmazonSQS} continue to - * exist and be active beyond the life of this call if there are any in-flight checkpoints. When - * no in-flight checkpoints remain, the reader will be closed. - */ - @Override - public void close() throws IOException { - active.set(false); - maybeCloseClient(); - } - - /** - * Close this reader's underlying {@link AmazonSQS} if the reader has been closed and there are no - * outstanding checkpoints. - */ - void maybeCloseClient() throws IOException { - if (!active.get() && numInFlightCheckpoints.get() == 0) { - // The reader has been closed and it has no more outstanding checkpoints. The client - // must be closed so it doesn't leak - if (sqsClient != null) { - sqsClient.shutdown(); - } - } - } - - /** delete the provided {@code messageIds} from SQS. */ - void delete(List messageIds) throws IOException { - AtomicInteger counter = new AtomicInteger(); - for (List messageList : - messageIds.stream() - .collect(groupingBy(x -> counter.getAndIncrement() / DELETE_BATCH_SIZE)) - .values()) { - deleteBatch(messageList); - } - } - - /** - * delete the provided {@code messageIds} from SQS, blocking until all of the messages are - * deleted. - * - *

CAUTION: May be invoked from a separate thread. - * - *

CAUTION: Retains {@code messageIds}. - */ - private void deleteBatch(List messageIds) throws IOException { - int retries = 0; - List errorMessages = new ArrayList<>(); - Map pendingReceipts = - IntStream.range(0, messageIds.size()) - .boxed() - .filter(i -> inFlight.containsKey(messageIds.get(i))) - .collect(toMap(Object::toString, i -> inFlight.get(messageIds.get(i)).receiptHandle)); - - while (!pendingReceipts.isEmpty()) { - - if (retries >= BATCH_OPERATION_MAX_RETIRES) { - throw new IOException( - "Failed to delete " - + pendingReceipts.size() - + " messages after " - + retries - + " retries: " - + String.join(", ", errorMessages)); - } - - List entries = - pendingReceipts.entrySet().stream() - .map(r -> new DeleteMessageBatchRequestEntry(r.getKey(), r.getValue())) - .collect(Collectors.toList()); - - DeleteMessageBatchResult result = - sqsClient.deleteMessageBatch(source.getRead().queueUrl(), entries); - - // Retry errors except invalid handles - Set retryErrors = - result.getFailed().stream() - .filter(e -> !e.getCode().equals("ReceiptHandleIsInvalid")) - .collect(Collectors.toSet()); - - pendingReceipts - .keySet() - .retainAll( - retryErrors.stream().map(BatchResultErrorEntry::getId).collect(Collectors.toSet())); - - errorMessages = - retryErrors.stream().map(BatchResultErrorEntry::getMessage).collect(Collectors.toList()); - - retries += 1; - } - deletedIds.add(messageIds); - } - - /** - * Messages which have been deleted (via the checkpoint finalize) are no longer in flight. This is - * only used for flow control and stats. - */ - private void retire() { - long nowMsSinceEpoch = now(); - while (true) { - List ackIds = deletedIds.poll(); - if (ackIds == null) { - return; - } - numDeleted.add(nowMsSinceEpoch, ackIds.size()); - for (String ackId : ackIds) { - inFlight.remove(ackId); - safeToDeleteIds.remove(ackId); - } - } - } - - /** BLOCKING Fetch another batch of messages from SQS. 
*/ - private void pull() { - if (inFlight.size() >= MAX_IN_FLIGHT) { - // Wait for checkpoint to be finalized before pulling anymore. - // There may be lag while checkpoints are persisted and the finalizeCheckpoint method - // is invoked. By limiting the in-flight messages we can ensure we don't end up consuming - // messages faster than we can checkpoint them. - return; - } - - long requestTimeMsSinceEpoch = now(); - long deadlineMsSinceEpoch = requestTimeMsSinceEpoch + visibilityTimeoutMs; - - final ReceiveMessageRequest receiveMessageRequest = - new ReceiveMessageRequest(source.getRead().queueUrl()); - - receiveMessageRequest.setMaxNumberOfMessages(MAX_NUMBER_OF_MESSAGES); - receiveMessageRequest.setAttributeNames( - Arrays.asList(MessageSystemAttributeName.SentTimestamp.toString())); - final ReceiveMessageResult receiveMessageResult = - sqsClient.receiveMessage(receiveMessageRequest); - - final List messages = receiveMessageResult.getMessages(); - - if (messages == null || messages.isEmpty()) { - numEmptyReceives.add(requestTimeMsSinceEpoch, 1L); - return; - } - - lastReceivedMsSinceEpoch = requestTimeMsSinceEpoch; - - // Capture the received messages. 
- for (Message message : messages) { - // Keep request time as message attribute for later usage - MessageAttributeValue reqTime = - new MessageAttributeValue().withStringValue(Long.toString(requestTimeMsSinceEpoch)); - message.setMessageAttributes(ImmutableMap.of(REQUEST_TIME, reqTime)); - messagesNotYetRead.add(message); - notYetReadBytes += message.getBody().getBytes(UTF_8).length; - inFlight.put( - message.getMessageId(), - new InFlightState( - message.getReceiptHandle(), requestTimeMsSinceEpoch, deadlineMsSinceEpoch)); - numReceived++; - numReceivedRecently.add(requestTimeMsSinceEpoch, 1L); - - long timestampMillis = getTimestamp(message).getMillis(); - minReceivedTimestampMsSinceEpoch.add(requestTimeMsSinceEpoch, timestampMillis); - maxReceivedTimestampMsSinceEpoch.add(requestTimeMsSinceEpoch, timestampMillis); - minUnreadTimestampMsSinceEpoch.add(requestTimeMsSinceEpoch, timestampMillis); - } - } - - /** Return the current time, in ms since epoch. */ - long now() { - return System.currentTimeMillis(); - } - - /** - * BLOCKING Extend deadline for all messages which need it. CAUTION: If extensions can't keep up - * with wallclock then we'll never return. - */ - private void extend() throws IOException { - while (true) { - long nowMsSinceEpoch = now(); - List assumeExpired = new ArrayList<>(); - List toBeExtended = new ArrayList<>(); - List toBeExpired = new ArrayList<>(); - // Messages will be in increasing deadline order. - for (Map.Entry entry : inFlight.entrySet()) { - if (entry.getValue().visibilityDeadlineMsSinceEpoch - - (visibilityTimeoutMs * VISIBILITY_SAFETY_PCT) / 100 - > nowMsSinceEpoch) { - // All remaining messages don't need their visibility timeouts to be extended. - break; - } - - if (entry.getValue().visibilityDeadlineMsSinceEpoch - VISIBILITY_TOO_LATE.getMillis() - < nowMsSinceEpoch) { - // SQS may have already considered this message to have expired. - // If so it will (eventually) be made available on a future pull request. 
- // If this message ends up being committed then it will be considered a duplicate - // when re-pulled. - assumeExpired.add(entry.getKey()); - continue; - } - - if (entry.getValue().requestTimeMsSinceEpoch + PROCESSING_TIMEOUT.getMillis() - < nowMsSinceEpoch) { - // This message has been in-flight for too long. - // Give up on it, otherwise we risk extending its visibility timeout indefinitely. - toBeExpired.add(entry.getKey()); - continue; - } - - // Extend the visibility timeout for this message. - toBeExtended.add(entry.getKey()); - if (toBeExtended.size() >= DELETE_BATCH_SIZE) { - // Enough for one batch. - break; - } - } - - if (assumeExpired.isEmpty() && toBeExtended.isEmpty() && toBeExpired.isEmpty()) { - // Nothing to be done. - return; - } - - if (!assumeExpired.isEmpty()) { - // If we didn't make the visibility deadline assume expired and no longer in flight. - numLateDeadlines.add(nowMsSinceEpoch, assumeExpired.size()); - for (String messageId : assumeExpired) { - inFlight.remove(messageId); - } - } - - if (!toBeExpired.isEmpty()) { - // Expired messages are no longer considered in flight. - numExpired.add(nowMsSinceEpoch, toBeExpired.size()); - for (String messageId : toBeExpired) { - inFlight.remove(messageId); - } - } - - if (!toBeExtended.isEmpty()) { - // SQS extends visibility timeout from it's notion of current time. - // We'll try to track that on our side, but note the deadlines won't necessarily agree. - long extensionMs = (int) ((visibilityTimeoutMs * VISIBILITY_EXTENSION_PCT) / 100L); - long newDeadlineMsSinceEpoch = nowMsSinceEpoch + extensionMs; - for (String messageId : toBeExtended) { - // Maintain increasing ack deadline order. 
- String receiptHandle = inFlight.get(messageId).receiptHandle; - InFlightState state = inFlight.remove(messageId); - - inFlight.put( - messageId, - new InFlightState( - receiptHandle, state.requestTimeMsSinceEpoch, newDeadlineMsSinceEpoch)); - } - List receiptHandles = - toBeExtended.stream() - .map(inFlight::get) - .filter(Objects::nonNull) // get rid of null values - .map(m -> m.receiptHandle) - .collect(Collectors.toList()); - // BLOCKs until extended. - extendBatch(nowMsSinceEpoch, receiptHandles, (int) (extensionMs / 1000)); - } - } - } - - /** - * BLOCKING Extend the visibility timeout for messages from SQS with the given {@code - * receiptHandles}. - */ - void extendBatch(long nowMsSinceEpoch, List receiptHandles, int extensionSec) - throws IOException { - int retries = 0; - int numMessages = receiptHandles.size(); - Map pendingReceipts = - IntStream.range(0, receiptHandles.size()) - .boxed() - .collect(toMap(Object::toString, receiptHandles::get)); - List errorMessages = new ArrayList<>(); - - while (!pendingReceipts.isEmpty()) { - - if (retries >= BATCH_OPERATION_MAX_RETIRES) { - throw new IOException( - "Failed to extend visibility timeout for " - + pendingReceipts.size() - + " messages after " - + retries - + " retries: " - + String.join(", ", errorMessages)); - } - - List entries = - pendingReceipts.entrySet().stream() - .map( - r -> - new ChangeMessageVisibilityBatchRequestEntry(r.getKey(), r.getValue()) - .withVisibilityTimeout(extensionSec)) - .collect(Collectors.toList()); - - ChangeMessageVisibilityBatchResult result = - sqsClient.changeMessageVisibilityBatch(source.getRead().queueUrl(), entries); - - // Retry errors except invalid handles - Set retryErrors = - result.getFailed().stream() - .filter(e -> !e.getCode().equals("ReceiptHandleIsInvalid")) - .collect(Collectors.toSet()); - - pendingReceipts - .keySet() - .retainAll( - retryErrors.stream().map(BatchResultErrorEntry::getId).collect(Collectors.toSet())); - - errorMessages = - 
retryErrors.stream().map(BatchResultErrorEntry::getMessage).collect(Collectors.toList()); - - retries += 1; - } - numExtendedDeadlines.add(nowMsSinceEpoch, numMessages); - } - - /** Log stats if time to do so. */ - private void stats() { - long nowMsSinceEpoch = now(); - if (lastLogTimestampMsSinceEpoch < 0) { - lastLogTimestampMsSinceEpoch = nowMsSinceEpoch; - return; - } - long deltaMs = nowMsSinceEpoch - lastLogTimestampMsSinceEpoch; - if (deltaMs < LOG_PERIOD.getMillis()) { - return; - } - - String messageSkew = "unknown"; - long minTimestamp = minReceivedTimestampMsSinceEpoch.get(nowMsSinceEpoch); - long maxTimestamp = maxReceivedTimestampMsSinceEpoch.get(nowMsSinceEpoch); - if (minTimestamp < Long.MAX_VALUE && maxTimestamp > Long.MIN_VALUE) { - messageSkew = (maxTimestamp - minTimestamp) + "ms"; - } - - String watermarkSkew = "unknown"; - long minWatermark = minWatermarkMsSinceEpoch.get(nowMsSinceEpoch); - long maxWatermark = maxWatermarkMsSinceEpoch.get(nowMsSinceEpoch); - if (minWatermark < Long.MAX_VALUE && maxWatermark > Long.MIN_VALUE) { - watermarkSkew = (maxWatermark - minWatermark) + "ms"; - } - - String oldestInFlight = "no"; - String oldestAckId = Iterables.getFirst(inFlight.keySet(), null); - if (oldestAckId != null) { - oldestInFlight = (nowMsSinceEpoch - inFlight.get(oldestAckId).requestTimeMsSinceEpoch) + "ms"; - } - - LOG.debug( - "SQS {} has " - + "{} received messages, " - + "{} current unread messages, " - + "{} current unread bytes, " - + "{} current in-flight msgs, " - + "{} oldest in-flight, " - + "{} current in-flight checkpoints, " - + "{} max in-flight checkpoints, " - + "{} bytes in backlog, " - + "{}B/s recent read, " - + "{} recent received, " - + "{} recent extended, " - + "{} recent late extended, " - + "{} recent deleted, " - + "{} recent released, " - + "{} recent expired, " - + "{} recent message timestamp skew, " - + "{} recent watermark skew, " - + "{} recent late messages, " - + "{} last reported watermark", - 
source.getRead().queueUrl(), - numReceived, - messagesNotYetRead.size(), - notYetReadBytes, - inFlight.size(), - oldestInFlight, - numInFlightCheckpoints.get(), - maxInFlightCheckpoints, - getTotalBacklogBytes(), - numReadBytes.get(nowMsSinceEpoch) / (SAMPLE_PERIOD.getMillis() / 1000L), - numReceivedRecently.get(nowMsSinceEpoch), - numExtendedDeadlines.get(nowMsSinceEpoch), - numLateDeadlines.get(nowMsSinceEpoch), - numDeleted.get(nowMsSinceEpoch), - numReleased.get(nowMsSinceEpoch), - numExpired.get(nowMsSinceEpoch), - messageSkew, - watermarkSkew, - numLateMessages.get(nowMsSinceEpoch), - new Instant(lastWatermarkMsSinceEpoch)); - - lastLogTimestampMsSinceEpoch = nowMsSinceEpoch; - } - - /** Return the average byte size of all message read. -1 if no message read yet */ - private long avgMessageBytes() { - if (!recentMessageBytes.isEmpty()) { - return (long) recentMessageBytes.stream().mapToDouble(s -> s).average().getAsDouble(); - } else { - return -1L; - } - } - - /** Extract the timestamp from the given {@code message}. */ - private Instant getTimestamp(final Message message) { - return new Instant( - Long.parseLong( - message.getAttributes().get(MessageSystemAttributeName.SentTimestamp.toString()))); - } - - /** Extract the request timestamp from the given {@code message}. */ - private Long getRequestTimeMsSinceEpoch(final Message message) { - return Long.parseLong(message.getMessageAttributes().get(REQUEST_TIME).getStringValue()); - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedSource.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedSource.java deleted file mode 100644 index 0ee9b8084179..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedSource.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import com.amazonaws.services.sqs.model.Message; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.SerializableCoder; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.sdk.io.aws.sqs.SqsIO.Read; -import org.apache.beam.sdk.options.PipelineOptions; -import org.checkerframework.checker.nullness.qual.Nullable; - -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class SqsUnboundedSource extends UnboundedSource { - - private final Read read; - private final SqsConfiguration sqsConfiguration; - private final Coder outputCoder; - - public SqsUnboundedSource( - Read read, SqsConfiguration sqsConfiguration, Coder outputCoder) { - this.read = read; - this.sqsConfiguration = sqsConfiguration; - this.outputCoder = outputCoder; - } - - @Override - public List split(int desiredNumSplits, PipelineOptions options) { - List sources = new ArrayList<>(); - for (int i = 0; i < Math.max(1, desiredNumSplits); ++i) { - sources.add(new SqsUnboundedSource(read, sqsConfiguration, outputCoder)); - } - return sources; - } - - @Override - public UnboundedReader createReader( - 
PipelineOptions options, @Nullable SqsCheckpointMark checkpointMark) { - try { - return new SqsUnboundedReader(this, checkpointMark); - } catch (IOException e) { - throw new RuntimeException("Unable to subscribe to " + read.queueUrl() + ": ", e); - } - } - - @Override - public Coder getCheckpointMarkCoder() { - return SerializableCoder.of(SqsCheckpointMark.class); - } - - @Override - public Coder getOutputCoder() { - return outputCoder; - } - - public Read getRead() { - return read; - } - - SqsConfiguration getSqsConfiguration() { - return sqsConfiguration; - } - - @Override - public boolean requiresDeduping() { - return true; - } -} diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/package-info.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/package-info.java deleted file mode 100644 index d688641ddff6..000000000000 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sqs/package-info.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** Defines IO connectors for Amazon Web Services SQS. 
*/ -package org.apache.beam.sdk.io.aws.sqs; diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java deleted file mode 100644 index 3415a11bf9f0..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws; - -import static org.apache.beam.sdk.testing.TestPipeline.testingPipelineOptions; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; - -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; -import org.apache.beam.sdk.io.aws.options.AwsOptions; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.TestPipelineOptions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.commons.lang3.StringUtils; -import org.junit.rules.ExternalResource; -import org.slf4j.LoggerFactory; -import org.testcontainers.containers.localstack.LocalStackContainer; -import org.testcontainers.containers.localstack.LocalStackContainer.Service; -import org.testcontainers.containers.output.Slf4jLogConsumer; -import org.testcontainers.utility.DockerImageName; - -/** - * JUnit rule providing an integration testing environment for AWS as {@link ExternalResource}. - * - *

This rule is typically used as @ClassRule. It starts a Localstack container with the requested - * AWS service and provides matching {@link AwsOptions}. The usage of localstack can also be - * disabled using {@link ITOptions} pipeline options to run integration tests against AWS, for - * instance: - * - *

{@code
- * ./gradlew :sdks:java:io:amazon-web-services:integrationTest \
- *   --info \
- *   --tests "org.apache.beam.sdk.io.aws.s3.S3FileSystemIT" \
- *   -DintegrationTestPipelineOptions='["--awsRegion=eu-central-1","--useLocalstack=false"]'
- * }
- * - * @param The options type to use for the integration test. - */ -public class ITEnvironment extends ExternalResource { - private static final String LOCALSTACK = "localstack/localstack"; - private static final String LOCALSTACK_VERSION = "0.13.1"; - - public interface ITOptions extends AwsOptions, TestPipelineOptions { - @Description("Number of rows to write and read by the test") - @Default.Integer(1000) - Integer getNumberOfRows(); - - void setNumberOfRows(Integer count); - - @Description("Flag if to use localstack, enabled by default.") - @Default.Boolean(true) - Boolean getUseLocalstack(); - - void setUseLocalstack(Boolean useLocalstack); - - @Description("Localstack log level, e.g. trace, debug, info") - String getLocalstackLogLevel(); - - void setLocalstackLogLevel(String level); - } - - private final OptionsT options; - private final LocalStackContainer localstack; - - public ITEnvironment(Service service, Class optionsClass, String... env) { - this(new Service[] {service}, optionsClass, env); - } - - public ITEnvironment(Service[] services, Class optionsClass, String... 
env) { - localstack = - new LocalStackContainer(DockerImageName.parse(LOCALSTACK).withTag(LOCALSTACK_VERSION)) - .withServices(services) - .withStartupAttempts(3); - - PipelineOptionsFactory.register(optionsClass); - options = testingPipelineOptions().as(optionsClass); - - localstack.setEnv(ImmutableList.copyOf(env)); - if (options.getLocalstackLogLevel() != null) { - localstack - .withEnv("LS_LOG", options.getLocalstackLogLevel()) - .withLogConsumer( - new Slf4jLogConsumer(LoggerFactory.getLogger(StringUtils.join(services)))); - } - } - - public TestPipeline createTestPipeline() { - return TestPipeline.fromOptions(options); - } - - public , ClientT> ClientT buildClient( - BuilderT builder) { - if (options.getAwsServiceEndpoint() != null) { - builder.withEndpointConfiguration( - new EndpointConfiguration(options.getAwsServiceEndpoint(), options.getAwsRegion())); - } else { - builder.setRegion(options.getAwsRegion()); - } - return builder.withCredentials(options.getAwsCredentialsProvider()).build(); - } - - public OptionsT options() { - return options; - } - - @Override - protected void before() { - if (options.getUseLocalstack()) { - startLocalstack(); - } - } - - @Override - protected void after() { - localstack.stop(); // noop if not started - } - - /** Necessary setup for localstack environment. 
*/ - private void startLocalstack() { - localstack.start(); - options.setAwsServiceEndpoint( - localstack.getEndpointOverride(S3).toString()); // service irrelevant - options.setAwsRegion(localstack.getRegion()); - options.setAwsCredentialsProvider( - new AWSStaticCredentialsProvider( - new BasicAWSCredentials(localstack.getAccessKey(), localstack.getSecretKey()))); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/coders/AwsCodersTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/coders/AwsCodersTest.java deleted file mode 100644 index 1ee20a6fa7ea..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/coders/AwsCodersTest.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.coders; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; - -import com.amazonaws.ResponseMetadata; -import com.amazonaws.http.HttpResponse; -import com.amazonaws.http.SdkHttpMetadata; -import java.util.UUID; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.junit.Test; - -/** Tests for AWS coders. */ -public class AwsCodersTest { - - @Test - public void testResponseMetadataDecodeEncodeEquals() throws Exception { - ResponseMetadata value = buildResponseMetadata(); - ResponseMetadata clone = CoderUtils.clone(AwsCoders.responseMetadata(), value); - assertThat(clone.getRequestId(), equalTo(value.getRequestId())); - } - - @Test - public void testSdkHttpMetadataDecodeEncodeEquals() throws Exception { - SdkHttpMetadata value = buildSdkHttpMetadata(); - SdkHttpMetadata clone = CoderUtils.clone(AwsCoders.sdkHttpMetadata(), value); - assertThat(clone.getHttpStatusCode(), equalTo(value.getHttpStatusCode())); - assertThat(clone.getHttpHeaders(), equalTo(value.getHttpHeaders())); - } - - @Test - public void testSdkHttpMetadataWithoutHeadersDecodeEncodeEquals() throws Exception { - SdkHttpMetadata value = buildSdkHttpMetadata(); - SdkHttpMetadata clone = CoderUtils.clone(AwsCoders.sdkHttpMetadataWithoutHeaders(), value); - assertThat(clone.getHttpStatusCode(), equalTo(value.getHttpStatusCode())); - assertThat(clone.getHttpHeaders().isEmpty(), equalTo(true)); - } - - private ResponseMetadata buildResponseMetadata() { - return new ResponseMetadata( - ImmutableMap.of(ResponseMetadata.AWS_REQUEST_ID, UUID.randomUUID().toString())); - } - - private SdkHttpMetadata buildSdkHttpMetadata() { - HttpResponse httpResponse = new HttpResponse(null, null); - httpResponse.setStatusCode(200); - httpResponse.addHeader("Content-Type", "application/json"); - return SdkHttpMetadata.from(httpResponse); - } -} 
diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoderTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoderTest.java deleted file mode 100644 index 489feb7a87c9..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/AttributeValueCoderTest.java +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import com.amazonaws.services.dynamodbv2.model.AttributeValue; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.junit.Assert; -import org.junit.Test; - -/** Unit test cases for each type of AttributeValue to test encoding and decoding. 
*/ -public class AttributeValueCoderTest { - - @Test - public void shouldPassForStringType() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setS("testing"); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForNumberType() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setN("123"); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForBooleanType() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setBOOL(false); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForByteArray() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setB(ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8))); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - 
public void shouldPassForListOfString() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setSS(ImmutableList.of("foo", "bar")); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForOneListOfNumber() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setNS(ImmutableList.of("123", "456")); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForOneListOfByteArray() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setBS( - ImmutableList.of( - ByteBuffer.wrap("mylistbyte1".getBytes(StandardCharsets.UTF_8)), - ByteBuffer.wrap("mylistbyte2".getBytes(StandardCharsets.UTF_8)))); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForListType() throws IOException { - AttributeValue expected = new AttributeValue(); - - List listAttr = new ArrayList<>(); - listAttr.add(new AttributeValue("innerMapValue1")); - listAttr.add(new AttributeValue().withN("8976234")); - - expected.setL(listAttr); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output 
= new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForMapType() throws IOException { - AttributeValue expected = new AttributeValue(); - - Map attrMap = new HashMap<>(); - attrMap.put("innerMapAttr1", new AttributeValue("innerMapValue1")); - attrMap.put( - "innerMapAttr2", - new AttributeValue().withB(ByteBuffer.wrap("8976234".getBytes(StandardCharsets.UTF_8)))); - - expected.setM(attrMap); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } - - @Test - public void shouldPassForNullType() throws IOException { - AttributeValue expected = new AttributeValue(); - expected.setNULL(true); - - AttributeValueCoder coder = AttributeValueCoder.of(); - ByteArrayOutputStream output = new ByteArrayOutputStream(); - coder.encode(expected, output); - - ByteArrayInputStream in = new ByteArrayInputStream(output.toByteArray()); - - AttributeValue actual = coder.decode(in); - - Assert.assertEquals(expected, actual); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOIT.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOIT.java deleted file mode 100644 index e3aa62450ce5..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOIT.java +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; -import static org.apache.beam.sdk.values.TypeDescriptors.strings; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.DYNAMODB; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder; -import com.amazonaws.services.dynamodbv2.model.AttributeDefinition; -import com.amazonaws.services.dynamodbv2.model.AttributeValue; -import com.amazonaws.services.dynamodbv2.model.CreateTableRequest; -import com.amazonaws.services.dynamodbv2.model.KeySchemaElement; -import com.amazonaws.services.dynamodbv2.model.KeyType; -import com.amazonaws.services.dynamodbv2.model.ProvisionedThroughput; -import com.amazonaws.services.dynamodbv2.model.PutRequest; -import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType; -import com.amazonaws.services.dynamodbv2.model.ScanRequest; -import com.amazonaws.services.dynamodbv2.model.TableStatus; -import com.amazonaws.services.dynamodbv2.model.WriteRequest; -import java.util.Map; -import org.apache.beam.sdk.io.GenerateSequence; -import 
org.apache.beam.sdk.io.aws.ITEnvironment; -import org.apache.beam.sdk.io.common.HashingFn; -import org.apache.beam.sdk.io.common.TestRow; -import org.apache.beam.sdk.io.common.TestRow.DeterministicallyConstructTestRowFn; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Flatten; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.junit.ClassRule; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -/** - * Integration test to write and read from DynamoDB. - * - *

By default this runs against Localstack, but you can use {@link DynamoDBIOIT.ITOptions} to - * configure tests to run against AWS DynamoDB. - * - *

{@code
- * ./gradlew :sdks:java:io:amazon-web-services:integrationTest \
- *   --info \
- *   --tests "org.apache.beam.sdk.io.aws.dynamodb.DynamoDBIOIT" \
- *   -DintegrationTestPipelineOptions='["--awsRegion=eu-central-1","--useLocalstack=false"]'
- * }
- */ -public class DynamoDBIOIT { - public interface ITOptions extends ITEnvironment.ITOptions { - @Description("DynamoDB table name") - @Default.String("beam-dynamodbio-it") - String getDynamoDBTable(); - - void setDynamoDBTable(String value); - - @Description("DynamoDB total segments") - @Default.Integer(2) - Integer getDynamoDBSegments(); - - void setDynamoDBSegments(Integer segments); - - @Description("Create DynamoDB table. Enabled when using localstack") - @Default.Boolean(false) - Boolean getCreateTable(); - - void setCreateTable(Boolean createTable); - } - - private static final String COL_ID = "id"; - private static final String COL_NAME = "name"; - - @ClassRule - public static ITEnvironment env = - new ITEnvironment<>(DYNAMODB, ITOptions.class, "DYNAMODB_ERROR_PROBABILITY=0.1"); - - @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); - @Rule public TestPipeline pipelineRead = env.createTestPipeline(); - @Rule public ExternalResource dbTable = CreateDbTable.optionally(env.options()); - - /** Test which write and then read data from DynamoDB. */ - @Test - public void testWriteThenRead() { - runWrite(); - runRead(); - } - - /** Write test dataset to DynamoDB. */ - private void runWrite() { - int rows = env.options().getNumberOfRows(); - pipelineWrite - .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) - .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) - .apply( - "Write to DynamoDB", - DynamoDBIO.write() - .withAwsClientsProvider(clientProvider()) - .withWriteRequestMapperFn(row -> buildWriteRequest(row))); - pipelineWrite.run().waitUntilFinish(); - } - - /** Read test dataset from DynamoDB. 
*/ - private void runRead() { - int rows = env.options().getNumberOfRows(); - PCollection> records = - pipelineRead - .apply( - "Read from DynamoDB", - DynamoDBIO.read() - .withAwsClientsProvider(clientProvider()) - .withScanRequestFn(in -> buildScanRequest()) - .items()) - .apply("Flatten result", Flatten.iterables()); - - PAssert.thatSingleton(records.apply("Count All", Count.globally())).isEqualTo((long) rows); - - PCollection consolidatedHashcode = - records - .apply(MapElements.into(strings()).via(record -> record.get(COL_NAME).getS())) - .apply("Hash records", Combine.globally(new HashingFn()).withoutDefaults()); - - PAssert.that(consolidatedHashcode).containsInAnyOrder(getExpectedHashForRowCount(rows)); - - pipelineRead.run().waitUntilFinish(); - } - - private AwsClientsProvider clientProvider() { - AWSCredentials credentials = env.options().getAwsCredentialsProvider().getCredentials(); - return new BasicDynamoDBProvider( - credentials.getAWSAccessKeyId(), - credentials.getAWSSecretKey(), - Regions.fromName(env.options().getAwsRegion()), - env.options().getAwsServiceEndpoint()); - } - - private static ScanRequest buildScanRequest() { - return new ScanRequest(env.options().getDynamoDBTable()) - .withTotalSegments(env.options().getDynamoDBSegments()); - } - - private static KV buildWriteRequest(TestRow row) { - AttributeValue id = new AttributeValue().withN(row.id().toString()); - AttributeValue name = new AttributeValue().withS(row.name()); - PutRequest req = new PutRequest(ImmutableMap.of(COL_ID, id, COL_NAME, name)); - return KV.of(env.options().getDynamoDBTable(), new WriteRequest().withPutRequest(req)); - } - - static class CreateDbTable extends ExternalResource { - static ExternalResource optionally(ITOptions opts) { - boolean create = opts.getCreateTable() || opts.getUseLocalstack(); - return create ? 
new CreateDbTable() : new ExternalResource() {}; - } - - private final String name = env.options().getDynamoDBTable(); - private final AmazonDynamoDB client = env.buildClient(AmazonDynamoDBClientBuilder.standard()); - - @Override - protected void before() throws Throwable { - CreateTableRequest request = - new CreateTableRequest() - .withTableName(name) - .withAttributeDefinitions( - attribute(COL_ID, ScalarAttributeType.N), - attribute(COL_NAME, ScalarAttributeType.S)) - .withKeySchema(keyElement(COL_ID, KeyType.HASH), keyElement(COL_NAME, KeyType.RANGE)) - .withProvisionedThroughput(new ProvisionedThroughput(1000L, 1000L)); - String status = client.createTable(request).getTableDescription().getTableStatus(); - int attempts = 10; - for (int i = 0; i <= attempts; ++i) { - if (status.equals(TableStatus.ACTIVE.toString())) { - return; - } - Thread.sleep(1000L); - status = client.describeTable(name).getTable().getTableStatus(); - } - throw new RuntimeException("Unable to initialize table"); - } - - @Override - protected void after() { - client.deleteTable(name); - client.shutdown(); - } - - private AttributeDefinition attribute(String name, ScalarAttributeType type) { - return new AttributeDefinition(name, type); - } - - private KeySchemaElement keyElement(String name, KeyType type) { - return new KeySchemaElement(name, type); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOReadTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOReadTest.java deleted file mode 100644 index 27e2a84076b7..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOReadTest.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import static java.lang.Math.min; -import static java.util.stream.Collectors.toList; -import static java.util.stream.IntStream.range; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables.getLast; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.newArrayList; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.transform; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import com.amazonaws.services.dynamodbv2.model.AttributeValue; -import com.amazonaws.services.dynamodbv2.model.ScanRequest; -import com.amazonaws.services.dynamodbv2.model.ScanResult; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.IntStream; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Flatten; -import org.apache.beam.sdk.values.PCollection; -import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.mockito.ArgumentMatcher; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -@RunWith(MockitoJUnitRunner.class) -public class DynamoDBIOReadTest { - private static final String tableName = "Test"; - - @Rule public final TestPipeline pipeline = TestPipeline.create(); - @Rule public final ExpectedException thrown = ExpectedException.none(); - @Mock public AmazonDynamoDB client; - - @Test - public void testReadOneSegment() { - MockData mockData = new MockData(range(0, 10)); - mockData.mockScan(10, client); // 1 scan iteration - - PCollection>> actual = - pipeline.apply( - DynamoDBIO.>>read() - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withScanRequestFn( - in -> new ScanRequest().withTableName(tableName).withTotalSegments(1)) - .items()); - - PAssert.that(actual.apply(Count.globally())).containsInAnyOrder(1L); - PAssert.that(actual).containsInAnyOrder(mockData.getAllItems()); - - pipeline.run().waitUntilFinish(); - } - - @Test - public void testReadWithCustomLimit() { - final int requestedLimit = 100; - MockData mockData = new MockData(range(0, 10)); - mockData.mockScan(requestedLimit, client); // 1 scan iteration - - pipeline.apply( - DynamoDBIO.>>read() - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withScanRequestFn( - in -> - new ScanRequest() - .withTableName(tableName) - .withTotalSegments(1) - .withLimit(requestedLimit)) - .items()); - - pipeline.run().waitUntilFinish(); - - verify(client).scan(argThat((ScanRequest req) -> requestedLimit == req.getLimit())); - } - - @Test - public void testReadThreeSegments() { - MockData mockData = new MockData(range(0, 10), range(10, 20), range(20, 30)); - mockData.mockScan(10, client); // 1 scan iteration per segment - - PCollection>> actual = - 
pipeline.apply( - DynamoDBIO.>>read() - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withScanRequestFn( - in -> new ScanRequest().withTableName(tableName).withTotalSegments(3)) - .items()); - - PAssert.that(actual.apply(Count.globally())).containsInAnyOrder(3L); - PAssert.that(actual.apply(Flatten.iterables())).containsInAnyOrder(mockData.getAllItems()); - - pipeline.run().waitUntilFinish(); - } - - @Test - public void testReadWithStartKey() { - MockData mockData = new MockData(range(0, 10), range(20, 32)); - mockData.mockScan(5, client); // 2 + 3 scan iterations - - PCollection>> actual = - pipeline.apply( - DynamoDBIO.>>read() - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withScanRequestFn( - in -> new ScanRequest().withTableName(tableName).withTotalSegments(2)) - .items()); - - PAssert.that(actual.apply(Count.globally())).containsInAnyOrder(5L); - PAssert.that(actual.apply(Flatten.iterables())).containsInAnyOrder(mockData.getAllItems()); - - pipeline.run().waitUntilFinish(); - } - - @Test - public void testReadMissingScanRequestFn() { - pipeline.enableAbandonedNodeEnforcement(false); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("withScanRequestFn() is required"); - - pipeline.apply(DynamoDBIO.read().withAwsClientsProvider(StaticAwsClientsProvider.of(client))); - } - - @Test - public void testReadMissingAwsClientsProvider() { - pipeline.enableAbandonedNodeEnforcement(false); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("withAwsClientsProvider() is required"); - - pipeline.apply(DynamoDBIO.read().withScanRequestFn(in -> new ScanRequest())); - } - - @Test - public void testReadMissingTotalSegments() { - pipeline.enableAbandonedNodeEnforcement(false); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("TotalSegments is required with withScanRequestFn() and greater zero"); - - pipeline.apply( - DynamoDBIO.read() - 
.withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withScanRequestFn(in -> new ScanRequest())); - } - - @Test - public void testReadInvalidTotalSegments() { - pipeline.enableAbandonedNodeEnforcement(false); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("TotalSegments is required with withScanRequestFn() and greater zero"); - - pipeline.apply( - DynamoDBIO.read() - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withScanRequestFn(in -> new ScanRequest().withTotalSegments(0))); - } - - private static class MockData { - private final List> data; - - MockData(IntStream... segments) { - data = Arrays.stream(segments).map(ids -> newArrayList(ids.iterator())).collect(toList()); - } - - List> getAllItems() { - return data.stream().flatMap(ids -> ids.stream()).map(id -> item(id)).collect(toList()); - } - - void mockScan(int sizeLimit, AmazonDynamoDB mock) { - for (int segment = 0; segment < data.size(); segment++) { - List ids = data.get(segment); - - List> items = null; - Map startKey, lastKey; - for (int start = 0; start < ids.size(); start += sizeLimit) { - startKey = items != null ? getLast(items) : null; - items = transform(ids.subList(start, min(ids.size(), start + sizeLimit)), id -> item(id)); - lastKey = start + sizeLimit < ids.size() ? 
getLast(items) : null; - - when(mock.scan(argThat(matchesScanRequest(segment, startKey)))) - .thenReturn(new ScanResult().withItems(items).withLastEvaluatedKey(lastKey)); - } - } - } - - ArgumentMatcher matchesScanRequest( - Integer segment, Map startKey) { - return req -> - req != null - && segment.equals(req.getSegment()) - && Objects.equals(startKey, req.getExclusiveStartKey()); - } - } - - private static Map item(int id) { - return ImmutableMap.of( - "rangeKey", new AttributeValue().withN(String.valueOf(id)), - "hashKey", new AttributeValue().withS(String.valueOf(id))); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOWriteTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOWriteTest.java deleted file mode 100644 index e49276ed4c40..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/DynamoDBIOWriteTest.java +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import static java.util.stream.Collectors.toList; -import static java.util.stream.IntStream.range; -import static java.util.stream.IntStream.rangeClosed; -import static org.apache.beam.sdk.io.aws.dynamodb.DynamoDBIO.Write.WriteFn.RETRY_ERROR_LOG; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps.filterKeys; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps.transformValues; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import com.amazonaws.services.dynamodbv2.model.AmazonDynamoDBException; -import com.amazonaws.services.dynamodbv2.model.AttributeValue; -import com.amazonaws.services.dynamodbv2.model.BatchWriteItemRequest; -import com.amazonaws.services.dynamodbv2.model.BatchWriteItemResult; -import com.amazonaws.services.dynamodbv2.model.DeleteRequest; -import com.amazonaws.services.dynamodbv2.model.PutRequest; -import com.amazonaws.services.dynamodbv2.model.WriteRequest; -import java.io.IOException; -import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.IntStream; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.coders.DefaultCoder; -import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; -import org.apache.beam.sdk.io.aws.dynamodb.DynamoDBIO.RetryConfiguration; -import org.apache.beam.sdk.io.aws.dynamodb.DynamoDBIO.Write.WriteFn; -import 
org.apache.beam.sdk.testing.ExpectedLogs; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SerializableBiFunction; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.joda.time.Duration; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatcher; -import org.mockito.InOrder; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; -import org.slf4j.helpers.MessageFormatter; - -@RunWith(MockitoJUnitRunner.class) -public class DynamoDBIOWriteTest { - private static final String tableName = "Test"; - - @Rule public final TestPipeline pipeline = TestPipeline.create(); - @Rule public final ExpectedLogs writeFnLogs = ExpectedLogs.none(WriteFn.class); - @Rule public final ExpectedException thrown = ExpectedException.none(); - - @Mock public AmazonDynamoDB client; - - @Test - public void testWritePutItems() { - List items = Item.range(0, 100); - - Supplier> capturePuts = - captureBatchWrites(client, req -> req.getPutRequest().getItem()); - - PCollection output = - pipeline - .apply(Create.of(items)) - .apply( - DynamoDBIO.write() - .withWriteRequestMapperFn(putRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client))); - - PAssert.that(output).empty(); - pipeline.run().waitUntilFinish(); - - assertThat(capturePuts.get()).containsExactlyInAnyOrderElementsOf(items); - } - - @Test - public void 
testWritePutItemsWithDuplicates() { - List items = Item.range(0, 100); - - Supplier>> captureRequests = - captureBatchWriteRequests(client, req -> req.getPutRequest().getItem()); - - pipeline - .apply(Create.of(items)) - // generate identical duplicates - .apply(ParDo.of(new AddDuplicatesDoFn(3, false))) - .apply( - DynamoDBIO.write() - .withWriteRequestMapperFn(putRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client))); - - pipeline.run().waitUntilFinish(); - - List> requests = captureRequests.get(); - for (List reqItems : requests) { - assertThat(reqItems).doesNotHaveDuplicates(); // each request is free of duplicates - } - - assertThat(requests.stream().flatMap(List::stream)).containsAll(items); - } - - @Test - public void testWritePutItemsWithDuplicatesByKey() { - ImmutableList keys = ImmutableList.of("id"); - List items = Item.range(0, 100); - - Supplier>> captureRequests = - captureBatchWriteRequests(client, req -> req.getPutRequest().getItem()); - - pipeline - .apply(Create.of(items)) - // decorate duplicates so they are different - .apply(ParDo.of(new AddDuplicatesDoFn(3, true))) - .apply( - DynamoDBIO.write() - .withWriteRequestMapperFn(putRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withDeduplicateKeys(keys)); - - pipeline.run().waitUntilFinish(); - - List> requests = captureRequests.get(); - for (List reqItems : requests) { - List keysOnly = - reqItems.stream() - .map(item -> new Item(filterKeys(item.entries, keys::contains))) - .collect(toList()); - assertThat(keysOnly).doesNotHaveDuplicates(); // each request is free of duplicates - } - - assertThat(requests.stream().flatMap(List::stream)).containsAll(items); - } - - @Test - public void testWriteDeleteItems() { - List items = Item.range(0, 100); - - Supplier> captureDeletes = - captureBatchWrites(client, req -> req.getDeleteRequest().getKey()); - - PCollection output = - pipeline - .apply(Create.of(items)) - .apply( - DynamoDBIO.write() - 
.withWriteRequestMapperFn(deleteRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client))); - - PAssert.that(output).empty(); - pipeline.run().waitUntilFinish(); - - assertThat(captureDeletes.get()).hasSize(100); - assertThat(captureDeletes.get()).containsExactlyInAnyOrderElementsOf(items); - } - - @Test - public void testWriteDeleteItemsWithDuplicates() { - List items = Item.range(0, 100); - - Supplier>> captureRequests = - captureBatchWriteRequests(client, req -> req.getDeleteRequest().getKey()); - - pipeline - .apply(Create.of(items)) - // generate identical duplicates - .apply(ParDo.of(new AddDuplicatesDoFn(3, false))) - .apply( - DynamoDBIO.write() - .withWriteRequestMapperFn(deleteRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client))); - - pipeline.run().waitUntilFinish(); - - List> requests = captureRequests.get(); - for (List reqItems : requests) { - assertThat(reqItems).doesNotHaveDuplicates(); // each request is free of duplicates - } - - assertThat(requests.stream().flatMap(List::stream)).containsAll(items); - } - - @Test - public void testWritePutItemsWithRetrySuccess() { - when(client.batchWriteItem(any(BatchWriteItemRequest.class))) - .thenThrow( - AmazonDynamoDBException.class, - AmazonDynamoDBException.class, - AmazonDynamoDBException.class) - .thenReturn(new BatchWriteItemResult().withUnprocessedItems(ImmutableMap.of())); - - pipeline - .apply(Create.of(Item.of(1))) - .apply( - "write", - DynamoDBIO.write() - .withWriteRequestMapperFn(putRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withRetryConfiguration(try4Times)); - - PipelineResult result = pipeline.run(); - result.waitUntilFinish(); - - verify(client, times(4)).batchWriteItem(any(BatchWriteItemRequest.class)); - } - - @Test - public void testWritePutItemsWithPartialSuccess() { - List writes = putRequests(Item.range(0, 10)); - - when(client.batchWriteItem(any(BatchWriteItemRequest.class))) - 
.thenReturn(partialWriteSuccess(writes.subList(4, 10))) - .thenReturn(partialWriteSuccess(writes.subList(8, 10))) - .thenReturn(new BatchWriteItemResult().withUnprocessedItems(ImmutableMap.of())); - - pipeline - .apply(Create.of(10)) // number if items to produce - .apply(ParDo.of(new GenerateItems())) // 10 items in one bundle - .apply( - "write", - DynamoDBIO.write() - .withWriteRequestMapperFn(putRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withRetryConfiguration(try4Times)); - - PipelineResult result = pipeline.run(); - result.waitUntilFinish(); - - verify(client, times(3)).batchWriteItem(any(BatchWriteItemRequest.class)); - - InOrder ordered = inOrder(client); - ordered.verify(client).batchWriteItem(argThat(matchWritesUnordered(writes))); - ordered.verify(client).batchWriteItem(argThat(matchWritesUnordered(writes.subList(4, 10)))); - ordered.verify(client).batchWriteItem(argThat(matchWritesUnordered(writes.subList(8, 10)))); - } - - @Test - public void testWritePutItemsWithRetryFailure() throws Throwable { - thrown.expect(IOException.class); - thrown.expectMessage("Error writing to DynamoDB"); - thrown.expectMessage("No more attempts allowed"); - - when(client.batchWriteItem(any(BatchWriteItemRequest.class))) - .thenThrow(AmazonDynamoDBException.class); - - pipeline - .apply(Create.of(Item.of(1))) - .apply( - DynamoDBIO.write() - .withWriteRequestMapperFn(putRequestMapper) - .withAwsClientsProvider(StaticAwsClientsProvider.of(client)) - .withRetryConfiguration(try4Times)); - - try { - pipeline.run().waitUntilFinish(); - } catch (final Pipeline.PipelineExecutionException e) { - verify(client, times(4)).batchWriteItem(any(BatchWriteItemRequest.class)); - writeFnLogs.verifyWarn(MessageFormatter.format(RETRY_ERROR_LOG, 4, "").getMessage()); - throw e.getCause(); - } - } - - @DefaultCoder(AvroCoder.class) - static class Item implements Serializable { - Map entries; - - private Item() {} - - private Item(Map entries) { - 
this.entries = entries; - } - - static Item of(int id) { - return new Item(ImmutableMap.of("id", String.valueOf(id))); - } - - static Item of(Map attributes) { - return new Item(ImmutableMap.copyOf(transformValues(attributes, a -> a.getS()))); - } - - static List range(int startInclusive, int endExclusive) { - return IntStream.range(startInclusive, endExclusive).mapToObj(Item::of).collect(toList()); - } - - Item withEntry(String key, String value) { - return new Item( - ImmutableMap.builder().putAll(entries).put(key, value).build()); - } - - Map attributeMap() { - return new HashMap<>(transformValues(entries, v -> new AttributeValue().withS(v))); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - return Objects.equals(entries, ((Item) o).entries); - } - - @Override - public int hashCode() { - return Objects.hash(entries); - } - - @Override - public String toString() { - return "Item" + entries; - } - } - - private Supplier>> captureBatchWriteRequests( - AmazonDynamoDB mock, Function> extractor) { - ArgumentCaptor reqCaptor = - ArgumentCaptor.forClass(BatchWriteItemRequest.class); - when(mock.batchWriteItem(reqCaptor.capture())) - .thenReturn(new BatchWriteItemResult().withUnprocessedItems(ImmutableMap.of())); - - return () -> - reqCaptor.getAllValues().stream() - .flatMap(req -> req.getRequestItems().values().stream()) - .map(writes -> writes.stream().map(extractor).map(Item::of).collect(toList())) - .collect(toList()); - } - - private Supplier> captureBatchWrites( - AmazonDynamoDB mock, Function> extractor) { - Supplier>> requests = captureBatchWriteRequests(mock, extractor); - return () -> requests.get().stream().flatMap(reqs -> reqs.stream()).collect(toList()); - } - - private static ArgumentMatcher matchWritesUnordered( - List writes) { - return (BatchWriteItemRequest req) -> - req != null - && req.getRequestItems().get(tableName).size() == 
writes.size() - && req.getRequestItems().get(tableName).containsAll(writes); - } - - private static BatchWriteItemResult partialWriteSuccess(List unprocessed) { - return new BatchWriteItemResult().withUnprocessedItems(ImmutableMap.of(tableName, unprocessed)); - } - - private static List putRequests(List items) { - return items.stream().map(putRequest).collect(toList()); - } - - private static Function putRequest = - item -> new WriteRequest().withPutRequest(new PutRequest().withItem(item.attributeMap())); - - private static Function deleteRequest = - key -> new WriteRequest().withDeleteRequest(new DeleteRequest().withKey(key.attributeMap())); - - private static SerializableFunction> putRequestMapper = - item -> KV.of(tableName, putRequest.apply(item)); - - private static SerializableFunction> deleteRequestMapper = - key -> KV.of(tableName, deleteRequest.apply(key)); - - private static RetryConfiguration try4Times = - RetryConfiguration.create(4, Duration.standardSeconds(1), Duration.millis(1)); - - private static class GenerateItems extends DoFn { - @ProcessElement - public void processElement(ProcessContext ctx) { - range(0, ctx.element()).forEach(i -> ctx.output(Item.of(i))); - } - } - - /** - * A DoFn that adds N duplicates to a bundle. The original is emitted last and is the only item - * kept if deduplicating appropriately. - */ - private static class AddDuplicatesDoFn extends DoFn { - private final int duplicates; - private final SerializableBiFunction decorator; - - AddDuplicatesDoFn(int duplicates, boolean decorate) { - this.duplicates = duplicates; - this.decorator = - decorate ? 
(item, i) -> item.withEntry("duplicate", i.toString()) : (item, i) -> item; - } - - @ProcessElement - public void processElement(ProcessContext ctx) { - Item original = ctx.element(); - rangeClosed(1, duplicates).forEach(i -> ctx.output(decorator.apply(original, i))); - ctx.output(original); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java deleted file mode 100644 index d3f676cf1096..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.dynamodb; - -import static java.util.Collections.synchronizedMap; - -import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; -import java.util.HashMap; -import java.util.Map; - -/** Client provider supporting unserializable clients such as mock instances for unit tests. 
*/ -class StaticAwsClientsProvider implements AwsClientsProvider { - private static final Map clients = synchronizedMap(new HashMap<>()); - - private final int id; - private final transient boolean cleanup; - - private StaticAwsClientsProvider(AmazonDynamoDB client) { - this.id = System.identityHashCode(client); - this.cleanup = true; - } - - static AwsClientsProvider of(AmazonDynamoDB client) { - StaticAwsClientsProvider provider = new StaticAwsClientsProvider(client); - clients.put(provider.id, client); - return provider; - } - - @Override - public AmazonDynamoDB createDynamoDB() { - return clients.get(id); - } - - @Override - protected void finalize() { - if (cleanup) { - clients.remove(id); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsHttpClientConfigurationTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsHttpClientConfigurationTest.java deleted file mode 100644 index f535c4271ac9..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsHttpClientConfigurationTest.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.options; - -import static org.junit.Assert.assertEquals; - -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * {@link AwsHttpClientConfigurationTest}. Test to verify that aws http client configuration are - * correctly being set for the respective AWS services. - */ -@RunWith(JUnit4.class) -public class AwsHttpClientConfigurationTest { - - @Test - public void testAwsHttpClientConfigurationValues() { - S3Options s3Options = getOptions(); - assertEquals(5000, s3Options.getClientConfiguration().getSocketTimeout()); - assertEquals(1000, s3Options.getClientConfiguration().getClientExecutionTimeout()); - assertEquals(10, s3Options.getClientConfiguration().getMaxConnections()); - } - - private static S3Options getOptions() { - String[] args = { - "--s3ClientFactoryClass=org.apache.beam.sdk.io.aws.s3.DefaultS3ClientBuilderFactory", - "--clientConfiguration={\"clientExecutionTimeout\":1000," - + "\"maxConnections\":10," - + "\"socketTimeout\":5000}" - }; - return PipelineOptionsFactory.fromArgs(args).as(S3Options.class); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java deleted file mode 100644 index 0099b08b7043..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/options/AwsModuleTest.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.options; - -import static org.apache.beam.repackaged.core.org.apache.commons.lang3.reflect.FieldUtils.readField; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.hasItem; -import static org.junit.Assert.assertEquals; - -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.BasicSessionCredentials; -import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper; -import com.amazonaws.auth.EnvironmentVariableCredentialsProvider; -import com.amazonaws.auth.PropertiesFileCredentialsProvider; -import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; -import com.amazonaws.auth.SystemPropertiesCredentialsProvider; -import com.amazonaws.auth.profile.ProfileCredentialsProvider; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import com.amazonaws.services.s3.model.SSECustomerKey; -import com.fasterxml.jackson.databind.Module; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.util.List; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.util.common.ReflectHelpers; -import org.apache.beam.sdk.util.construction.PipelineOptionsTranslation; -import org.hamcrest.Matchers; -import org.junit.Test; -import 
org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link AwsModule}. */ -@RunWith(JUnit4.class) -public class AwsModuleTest { - - private final ObjectMapper objectMapper = new ObjectMapper().registerModule(new AwsModule()); - - @Test - public void testObjectMapperIsAbleToFindModule() { - List modules = ObjectMapper.findModules(ReflectHelpers.findClassLoader()); - assertThat(modules, hasItem(Matchers.instanceOf(AwsModule.class))); - } - - @Test - public void testAWSStaticCredentialsProviderSerializationDeserialization() throws Exception { - String awsKeyId = "key-id"; - String awsSecretKey = "secret-key"; - - AWSStaticCredentialsProvider credentialsProvider = - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsKeyId, awsSecretKey)); - - String serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - AWSCredentialsProvider deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals( - credentialsProvider.getCredentials().getAWSAccessKeyId(), - deserializedCredentialsProvider.getCredentials().getAWSAccessKeyId()); - assertEquals( - credentialsProvider.getCredentials().getAWSSecretKey(), - deserializedCredentialsProvider.getCredentials().getAWSSecretKey()); - - String sessionToken = "session-token"; - BasicSessionCredentials sessionCredentials = - new BasicSessionCredentials(awsKeyId, awsSecretKey, sessionToken); - credentialsProvider = new AWSStaticCredentialsProvider(sessionCredentials); - serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - BasicSessionCredentials deserializedCredentials = - (BasicSessionCredentials) deserializedCredentialsProvider.getCredentials(); - 
assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals(deserializedCredentials.getAWSAccessKeyId(), awsKeyId); - assertEquals(deserializedCredentials.getAWSSecretKey(), awsSecretKey); - assertEquals(deserializedCredentials.getSessionToken(), sessionToken); - } - - @Test - public void testPropertiesFileCredentialsProviderSerializationDeserialization() throws Exception { - String credentialsFilePath = "/path/to/file"; - - PropertiesFileCredentialsProvider credentialsProvider = - new PropertiesFileCredentialsProvider(credentialsFilePath); - - String serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - AWSCredentialsProvider deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals( - credentialsFilePath, - readField(deserializedCredentialsProvider, "credentialsFilePath", true)); - } - - @Test - public void testClasspathPropertiesFileCredentialsProviderSerializationDeserialization() - throws Exception { - String credentialsFilePath = "/path/to/file"; - - ClasspathPropertiesFileCredentialsProvider credentialsProvider = - new ClasspathPropertiesFileCredentialsProvider(credentialsFilePath); - - String serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - AWSCredentialsProvider deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals( - credentialsFilePath, - readField(deserializedCredentialsProvider, "credentialsFilePath", true)); - } - - @Test - public void testSTSAssumeRoleSessionCredentialsProviderSerializationDeserialization() - throws Exception { - String roleArn = "arn:aws:iam::000111222333:role/TestRole"; - 
String roleSessionName = "roleSessionName"; - STSAssumeRoleSessionCredentialsProvider credentialsProvider = - new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, roleSessionName).build(); - String serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - AWSCredentialsProvider deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals(roleArn, readField(deserializedCredentialsProvider, "roleArn", true)); - assertEquals( - roleSessionName, readField(deserializedCredentialsProvider, "roleSessionName", true)); - } - - @Test - public void testSingletonAWSCredentialsProviderSerializationDeserialization() throws Exception { - AWSCredentialsProvider credentialsProvider; - String serializedCredentialsProvider; - AWSCredentialsProvider deserializedCredentialsProvider; - - credentialsProvider = new DefaultAWSCredentialsProviderChain(); - serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new EnvironmentVariableCredentialsProvider(); - serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new SystemPropertiesCredentialsProvider(); - serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, 
AWSCredentialsProvider.class); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new ProfileCredentialsProvider(); - serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new EC2ContainerCredentialsProviderWrapper(); - serializedCredentialsProvider = objectMapper.writeValueAsString(credentialsProvider); - deserializedCredentialsProvider = - objectMapper.readValue(serializedCredentialsProvider, AWSCredentialsProvider.class); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - } - - @Test - public void testSSECustomerKeySerializationDeserialization() throws Exception { - final String key = "86glyTlCNZgccSxW8JxMa6ZdjdK3N141glAysPUZ3AA="; - final String md5 = null; - final String algorithm = "AES256"; - - SSECustomerKey value = new SSECustomerKey(key); - - String valueAsJson = objectMapper.writeValueAsString(value); - SSECustomerKey valueDes = objectMapper.readValue(valueAsJson, SSECustomerKey.class); - assertEquals(key, valueDes.getKey()); - assertEquals(algorithm, valueDes.getAlgorithm()); - assertEquals(md5, valueDes.getMd5()); - } - - @Test - public void testSSEAwsKeyManagementParamsSerializationDeserialization() throws Exception { - final String awsKmsKeyId = - "arn:aws:kms:eu-west-1:123456789012:key/dc123456-7890-ABCD-EF01-234567890ABC"; - final String encryption = "aws:kms"; - SSEAwsKeyManagementParams value = new SSEAwsKeyManagementParams(awsKmsKeyId); - - String valueAsJson = objectMapper.writeValueAsString(value); - SSEAwsKeyManagementParams valueDes = - objectMapper.readValue(valueAsJson, SSEAwsKeyManagementParams.class); - assertEquals(awsKmsKeyId, valueDes.getAwsKmsKeyId()); - 
assertEquals(encryption, valueDes.getEncryption()); - } - - @Test - public void testClientConfigurationSerializationDeserialization() throws Exception { - ClientConfiguration clientConfiguration = new ClientConfiguration(); - clientConfiguration.setProxyHost("localhost"); - clientConfiguration.setProxyPort(1234); - clientConfiguration.setProxyUsername("username"); - clientConfiguration.setProxyPassword("password"); - - final String valueAsJson = objectMapper.writeValueAsString(clientConfiguration); - final ClientConfiguration valueDes = - objectMapper.readValue(valueAsJson, ClientConfiguration.class); - assertEquals("localhost", valueDes.getProxyHost()); - assertEquals(1234, valueDes.getProxyPort()); - assertEquals("username", valueDes.getProxyUsername()); - assertEquals("password", valueDes.getProxyPassword()); - } - - @Test - public void testAwsHttpClientConfigurationSerializationDeserialization() throws Exception { - ClientConfiguration clientConfiguration = new ClientConfiguration(); - clientConfiguration.setConnectionTimeout(100); - clientConfiguration.setConnectionMaxIdleMillis(1000); - clientConfiguration.setSocketTimeout(300); - - final String valueAsJson = objectMapper.writeValueAsString(clientConfiguration); - final ClientConfiguration clientConfigurationDeserialized = - objectMapper.readValue(valueAsJson, ClientConfiguration.class); - assertEquals(100, clientConfigurationDeserialized.getConnectionTimeout()); - assertEquals(1000, clientConfigurationDeserialized.getConnectionMaxIdleMillis()); - assertEquals(300, clientConfigurationDeserialized.getSocketTimeout()); - } - - @Test - public void testAwsHttpClientConfigurationSerializationDeserializationProto() throws Exception { - AwsOptions awsOptions = - PipelineOptionsTranslation.fromProto( - PipelineOptionsTranslation.toProto( - PipelineOptionsFactory.fromArgs( - "--clientConfiguration={ \"connectionTimeout\": 100, \"connectionMaxIdleTime\": 1000, \"socketTimeout\": 300, \"proxyPort\": -1, 
\"requestTimeout\": 1500 }") - .create())) - .as(AwsOptions.class); - ClientConfiguration clientConfigurationDeserialized = awsOptions.getClientConfiguration(); - - assertEquals(100, clientConfigurationDeserialized.getConnectionTimeout()); - assertEquals(1000, clientConfigurationDeserialized.getConnectionMaxIdleMillis()); - assertEquals(300, clientConfigurationDeserialized.getSocketTimeout()); - assertEquals(-1, clientConfigurationDeserialized.getProxyPort()); - assertEquals(1500, clientConfigurationDeserialized.getRequestTimeout()); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/MatchResultMatcher.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/MatchResultMatcher.java deleted file mode 100644 index e6b127947df0..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/MatchResultMatcher.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import java.io.IOException; -import java.util.List; -import org.apache.beam.sdk.io.fs.MatchResult; -import org.apache.beam.sdk.io.fs.ResourceId; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.hamcrest.BaseMatcher; -import org.hamcrest.Description; -import org.hamcrest.Matcher; - -/** - * Hamcrest {@link Matcher} to match {@link MatchResult}. Necessary because {@link - * MatchResult#metadata()} throws an exception under normal circumstances. - */ -class MatchResultMatcher extends BaseMatcher { - - private final MatchResult.Status expectedStatus; - private final List expectedMetadata; - private final IOException expectedException; - - private MatchResultMatcher( - MatchResult.Status expectedStatus, - List expectedMetadata, - IOException expectedException) { - this.expectedStatus = checkNotNull(expectedStatus); - checkArgument((expectedMetadata == null) ^ (expectedException == null)); - this.expectedMetadata = expectedMetadata; - this.expectedException = expectedException; - } - - static MatchResultMatcher create(List expectedMetadata) { - return new MatchResultMatcher(MatchResult.Status.OK, expectedMetadata, null); - } - - private static MatchResultMatcher create(MatchResult.Metadata expectedMetadata) { - return create(ImmutableList.of(expectedMetadata)); - } - - static MatchResultMatcher create( - long sizeBytes, long lastModifiedMillis, ResourceId resourceId, boolean isReadSeekEfficient) { - return create( - MatchResult.Metadata.builder() - .setSizeBytes(sizeBytes) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(resourceId) - .setIsReadSeekEfficient(isReadSeekEfficient) - .build()); - } - - static MatchResultMatcher create( - 
MatchResult.Status expectedStatus, IOException expectedException) { - return new MatchResultMatcher(expectedStatus, null, expectedException); - } - - static MatchResultMatcher create(MatchResult expected) { - MatchResult.Status expectedStatus = expected.status(); - List expectedMetadata = null; - IOException expectedException = null; - try { - expectedMetadata = expected.metadata(); - } catch (IOException e) { - expectedException = e; - } - return new MatchResultMatcher(expectedStatus, expectedMetadata, expectedException); - } - - @Override - public boolean matches(Object actual) { - if (actual == null) { - return false; - } - if (!(actual instanceof MatchResult)) { - return false; - } - MatchResult actualResult = (MatchResult) actual; - if (!expectedStatus.equals(actualResult.status())) { - return false; - } - - List actualMetadata; - try { - actualMetadata = actualResult.metadata(); - } catch (IOException e) { - return expectedException != null && expectedException.toString().equals(e.toString()); - } - return expectedMetadata != null && expectedMetadata.equals(actualMetadata); - } - - @Override - public void describeTo(Description description) { - if (expectedMetadata != null) { - description.appendText(MatchResult.create(expectedStatus, expectedMetadata).toString()); - } else { - description.appendText(MatchResult.create(expectedStatus, expectedException).toString()); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java deleted file mode 100644 index 112ab95463b4..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; -import static org.apache.commons.lang3.StringUtils.isAllLowerCase; -import static org.apache.http.HttpHeaders.CONTENT_LENGTH; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; - -import com.amazonaws.Request; -import com.amazonaws.handlers.RequestHandler2; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import java.util.Map; -import org.apache.beam.sdk.io.GenerateSequence; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.io.aws.ITEnvironment; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.io.common.HashingFn; -import org.apache.beam.sdk.io.common.TestRow.DeterministicallyConstructTestRowFn; -import org.apache.beam.sdk.io.common.TestRow.SelectNameFn; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.joda.time.DateTime; -import org.junit.ClassRule; -import org.junit.Rule; 
-import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Integration test to write and read from a S3 compatible file system. - * - *

By default this runs against Localstack, but you can use {@link S3FileSystemIT.S3ITOptions} to - * configure tests to run against AWS S3. - * - *

{@code
- * ./gradlew :sdks:java:io:amazon-web-services:integrationTest \
- *   --info \
- *   --tests "org.apache.beam.sdk.io.aws.s3.S3FileSystemIT" \
- *   -DintegrationTestPipelineOptions='["--awsRegion=eu-central-1","--useLocalstack=false"]'
- * }
- */ -@RunWith(JUnit4.class) -public class S3FileSystemIT { - public interface S3ITOptions extends ITEnvironment.ITOptions, S3Options {} - - @ClassRule - public static ITEnvironment env = - new ITEnvironment(S3, S3ITOptions.class) { - @Override - protected void before() { - super.before(); - options().setS3ClientFactoryClass(S3ClientFixFix.class); - } - }; - - @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); - @Rule public TestPipeline pipelineRead = env.createTestPipeline(); - @Rule public S3Bucket s3Bucket = new S3Bucket(); - - @Test - public void testWriteThenRead() { - int rows = env.options().getNumberOfRows(); - // Write test dataset to S3. - pipelineWrite - .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) - .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) - .apply("Prepare file rows", ParDo.of(new SelectNameFn())) - .apply("Write to S3 file", TextIO.write().to("s3://" + s3Bucket.name + "/test")); - - pipelineWrite.run().waitUntilFinish(); - - // Read test dataset from S3. 
- PCollection output = - pipelineRead.apply(TextIO.read().from("s3://" + s3Bucket.name + "/test*")); - - PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows); - - PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults())) - .containsInAnyOrder(getExpectedHashForRowCount(rows)); - - pipelineRead.run().waitUntilFinish(); - } - - static class S3Bucket extends ExternalResource { - public final String name = "beam-s3io-it-" + new DateTime().toString("yyyyMMdd-HHmmss"); - - @Override - protected void before() { - AmazonS3 client = env.buildClient(AmazonS3ClientBuilder.standard()); - client.createBucket(name); - client.shutdown(); - } - } - - // Fix duplicated Content-Length header due to case-sensitive handling of header names - // https://github.com/aws/aws-sdk-java/issues/2503 - private static class S3ClientFixFix extends DefaultS3ClientBuilderFactory { - @Override - public AmazonS3ClientBuilder createBuilder(S3Options s3Options) { - return super.createBuilder(s3Options) - .withRequestHandlers( - new RequestHandler2() { - @Override - public void beforeRequest(Request request) { - Map headers = request.getHeaders(); - if (!isAllLowerCase(CONTENT_LENGTH) && headers.containsKey(CONTENT_LENGTH)) { - headers.remove(CONTENT_LENGTH.toLowerCase()); // remove duplicated header - } - } - }); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java deleted file mode 100644 index db749d7080e2..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java +++ /dev/null @@ -1,1248 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.buildMockedS3FileSystem; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3Config; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3ConfigWithCustomEndpointAndPathStyleAccessEnabled; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3ConfigWithSSECustomerKey; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3Options; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3OptionsWithCustomEndpointAndPathStyleAccessEnabled; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3OptionsWithSSECustomerKey; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.toMd5; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Matchers.anyObject; -import static org.mockito.Matchers.notNull; -import static org.mockito.Mockito.mock; -import static 
org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import akka.http.scaladsl.Http; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.AnonymousAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.AmazonS3Exception; -import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest; -import com.amazonaws.services.s3.model.CopyObjectRequest; -import com.amazonaws.services.s3.model.CopyObjectResult; -import com.amazonaws.services.s3.model.CopyPartRequest; -import com.amazonaws.services.s3.model.CopyPartResult; -import com.amazonaws.services.s3.model.DeleteObjectsRequest; -import com.amazonaws.services.s3.model.GetObjectMetadataRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadResult; -import com.amazonaws.services.s3.model.ListObjectsV2Request; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.S3ObjectSummary; -import com.amazonaws.services.s3.model.SSECustomerKey; -import io.findify.s3mock.S3Mock; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.net.URISyntaxException; -import java.net.URL; -import java.nio.ByteBuffer; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.io.fs.CreateOptions; -import org.apache.beam.sdk.io.fs.MatchResult; -import org.apache.beam.sdk.metrics.Lineage; -import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentMatcher; - -/** Test case for {@link S3FileSystem}. */ -@RunWith(JUnit4.class) -public class S3FileSystemTest { - private static S3Mock api; - private static AmazonS3 client; - - @BeforeClass - public static void beforeClass() { - api = new S3Mock.Builder().withInMemoryBackend().build(); - Http.ServerBinding binding = api.start(); - - EndpointConfiguration endpoint = - new EndpointConfiguration( - "http://localhost:" + binding.localAddress().getPort(), "us-west-2"); - client = - AmazonS3ClientBuilder.standard() - .withPathStyleAccessEnabled(true) - .withEndpointConfiguration(endpoint) - .withCredentials(new AWSStaticCredentialsProvider(new AnonymousAWSCredentials())) - .build(); - } - - @AfterClass - public static void afterClass() { - api.stop(); - } - - @Test - public void testGetScheme() { - S3FileSystem s3FileSystem = new S3FileSystem(s3Config("s3")); - assertEquals("s3", s3FileSystem.getScheme()); - - s3FileSystem = new S3FileSystem(s3Config("other")); - assertEquals("other", s3FileSystem.getScheme()); - } - - @Test - public void testGetSchemeWithS3Options() { - S3FileSystem s3FileSystem = new S3FileSystem(s3Options()); - assertEquals("s3", s3FileSystem.getScheme()); - } - - @Test - public void testGetPathStyleAccessEnabled() throws URISyntaxException { - S3FileSystem s3FileSystem = - new S3FileSystem(s3ConfigWithCustomEndpointAndPathStyleAccessEnabled("s3")); - URL s3Url = s3FileSystem.getAmazonS3Client().getUrl("bucket", "file"); - assertEquals("https://s3.custom.dns/bucket/file", s3Url.toURI().toString()); - } - - @Test - public void testGetPathStyleAccessEnabledWithS3Options() throws URISyntaxException { - S3FileSystem s3FileSystem = - new 
S3FileSystem(s3OptionsWithCustomEndpointAndPathStyleAccessEnabled()); - URL s3Url = s3FileSystem.getAmazonS3Client().getUrl("bucket", "file"); - assertEquals("https://s3.custom.dns/bucket/file", s3Url.toURI().toString()); - } - - @Test - public void testCopy() throws IOException { - testCopy(s3Config("s3")); - testCopy(s3Config("other")); - testCopy(s3ConfigWithSSECustomerKey("s3")); - testCopy(s3ConfigWithSSECustomerKey("other")); - } - - @Test - public void testCopyWithS3Options() throws IOException { - testCopy(s3Options()); - testCopy(s3OptionsWithSSECustomerKey()); - } - - private GetObjectMetadataRequest createObjectMetadataRequest( - S3ResourceId path, SSECustomerKey sseCustomerKey) { - GetObjectMetadataRequest getObjectMetadataRequest = - new GetObjectMetadataRequest(path.getBucket(), path.getKey()); - getObjectMetadataRequest.setSSECustomerKey(sseCustomerKey); - return getObjectMetadataRequest; - } - - private void assertGetObjectMetadata( - S3FileSystem s3FileSystem, - GetObjectMetadataRequest request, - String sseCustomerKeyMd5, - ObjectMetadata objectMetadata) { - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata(argThat(new GetObjectMetadataRequestMatcher(request)))) - .thenReturn(objectMetadata); - assertEquals( - sseCustomerKeyMd5, - s3FileSystem.getAmazonS3Client().getObjectMetadata(request).getSSECustomerKeyMd5()); - } - - private void testCopy(S3FileSystemConfiguration config) throws IOException { - testCopy(buildMockedS3FileSystem(config), config.getSSECustomerKey()); - } - - private void testCopy(S3Options options) throws IOException { - testCopy(buildMockedS3FileSystem(options), options.getSSECustomerKey()); - } - - private void testCopy(S3FileSystem s3FileSystem, SSECustomerKey sseCustomerKey) - throws IOException { - S3ResourceId sourcePath = S3ResourceId.fromUri(s3FileSystem.getScheme() + "://bucket/from"); - S3ResourceId destinationPath = S3ResourceId.fromUri(s3FileSystem.getScheme() + "://bucket/to"); - - ObjectMetadata 
objectMetadata = new ObjectMetadata(); - objectMetadata.setContentLength(0); - String sseCustomerKeyMd5 = toMd5(sseCustomerKey); - if (sseCustomerKeyMd5 != null) { - objectMetadata.setSSECustomerKeyMd5(sseCustomerKeyMd5); - } - assertGetObjectMetadata( - s3FileSystem, - createObjectMetadataRequest(sourcePath, sseCustomerKey), - sseCustomerKeyMd5, - objectMetadata); - - s3FileSystem.copy(sourcePath, destinationPath); - - verify(s3FileSystem.getAmazonS3Client(), times(1)).copyObject(any(CopyObjectRequest.class)); - - // we simulate a big object >= 5GB so it takes the multiPart path - objectMetadata.setContentLength(5_368_709_120L); - assertGetObjectMetadata( - s3FileSystem, - createObjectMetadataRequest(sourcePath, sseCustomerKey), - sseCustomerKeyMd5, - objectMetadata); - - try { - s3FileSystem.copy(sourcePath, destinationPath); - } catch (NullPointerException e) { - // ignore failing unmocked path, this is covered by testMultipartCopy test - } - - verify(s3FileSystem.getAmazonS3Client(), never()).copyObject(null); - } - - @Test - public void testAtomicCopy() { - testAtomicCopy(s3Config("s3")); - testAtomicCopy(s3Config("other")); - testAtomicCopy(s3ConfigWithSSECustomerKey("s3")); - testAtomicCopy(s3ConfigWithSSECustomerKey("other")); - } - - @Test - public void testAtomicCopyWithS3Options() { - testAtomicCopy(s3Options()); - testAtomicCopy(s3OptionsWithSSECustomerKey()); - } - - private void testAtomicCopy(S3FileSystemConfiguration config) { - testAtomicCopy(buildMockedS3FileSystem(config), config.getSSECustomerKey()); - } - - private void testAtomicCopy(S3Options options) { - testAtomicCopy(buildMockedS3FileSystem(options), options.getSSECustomerKey()); - } - - private void testAtomicCopy(S3FileSystem s3FileSystem, SSECustomerKey sseCustomerKey) { - S3ResourceId sourcePath = S3ResourceId.fromUri(s3FileSystem.getScheme() + "://bucket/from"); - S3ResourceId destinationPath = S3ResourceId.fromUri(s3FileSystem.getScheme() + "://bucket/to"); - - CopyObjectResult 
copyObjectResult = new CopyObjectResult(); - String sseCustomerKeyMd5 = toMd5(sseCustomerKey); - if (sseCustomerKeyMd5 != null) { - copyObjectResult.setSSECustomerKeyMd5(sseCustomerKeyMd5); - } - CopyObjectRequest copyObjectRequest = - new CopyObjectRequest( - sourcePath.getBucket(), - sourcePath.getKey(), - destinationPath.getBucket(), - destinationPath.getKey()); - copyObjectRequest.setSourceSSECustomerKey(sseCustomerKey); - copyObjectRequest.setDestinationSSECustomerKey(sseCustomerKey); - when(s3FileSystem.getAmazonS3Client().copyObject(any(CopyObjectRequest.class))) - .thenReturn(copyObjectResult); - assertEquals( - sseCustomerKeyMd5, - s3FileSystem.getAmazonS3Client().copyObject(copyObjectRequest).getSSECustomerKeyMd5()); - - ObjectMetadata sourceS3ObjectMetadata = new ObjectMetadata(); - s3FileSystem.atomicCopy(sourcePath, destinationPath, sourceS3ObjectMetadata); - - verify(s3FileSystem.getAmazonS3Client(), times(2)).copyObject(any(CopyObjectRequest.class)); - } - - @Test - public void testMultipartCopy() { - testMultipartCopy(s3Config("s3")); - testMultipartCopy(s3Config("other")); - testMultipartCopy(s3ConfigWithSSECustomerKey("s3")); - testMultipartCopy(s3ConfigWithSSECustomerKey("other")); - } - - @Test - public void testMultipartCopyWithS3Options() { - testMultipartCopy(s3Options()); - testMultipartCopy(s3OptionsWithSSECustomerKey()); - } - - private void testMultipartCopy(S3FileSystemConfiguration config) { - testMultipartCopy( - buildMockedS3FileSystem(config), - config.getSSECustomerKey(), - config.getS3UploadBufferSizeBytes()); - } - - private void testMultipartCopy(S3Options options) { - testMultipartCopy( - buildMockedS3FileSystem(options), - options.getSSECustomerKey(), - options.getS3UploadBufferSizeBytes()); - } - - private void testMultipartCopy( - S3FileSystem s3FileSystem, SSECustomerKey sseCustomerKey, long s3UploadBufferSizeBytes) { - S3ResourceId sourcePath = S3ResourceId.fromUri(s3FileSystem.getScheme() + "://bucket/from"); - 
S3ResourceId destinationPath = S3ResourceId.fromUri(s3FileSystem.getScheme() + "://bucket/to"); - - InitiateMultipartUploadResult initiateMultipartUploadResult = - new InitiateMultipartUploadResult(); - initiateMultipartUploadResult.setUploadId("upload-id"); - String sseCustomerKeyMd5 = toMd5(sseCustomerKey); - if (sseCustomerKeyMd5 != null) { - initiateMultipartUploadResult.setSSECustomerKeyMd5(sseCustomerKeyMd5); - } - when(s3FileSystem - .getAmazonS3Client() - .initiateMultipartUpload(any(InitiateMultipartUploadRequest.class))) - .thenReturn(initiateMultipartUploadResult); - assertEquals( - sseCustomerKeyMd5, - s3FileSystem - .getAmazonS3Client() - .initiateMultipartUpload( - new InitiateMultipartUploadRequest( - destinationPath.getBucket(), destinationPath.getKey())) - .getSSECustomerKeyMd5()); - - ObjectMetadata sourceObjectMetadata = new ObjectMetadata(); - sourceObjectMetadata.setContentLength((long) (s3UploadBufferSizeBytes * 1.5)); - sourceObjectMetadata.setContentEncoding("read-seek-efficient"); - if (sseCustomerKeyMd5 != null) { - sourceObjectMetadata.setSSECustomerKeyMd5(sseCustomerKeyMd5); - } - assertGetObjectMetadata( - s3FileSystem, - createObjectMetadataRequest(sourcePath, sseCustomerKey), - sseCustomerKeyMd5, - sourceObjectMetadata); - - CopyPartResult copyPartResult1 = new CopyPartResult(); - copyPartResult1.setETag("etag-1"); - CopyPartResult copyPartResult2 = new CopyPartResult(); - copyPartResult1.setETag("etag-2"); - if (sseCustomerKeyMd5 != null) { - copyPartResult1.setSSECustomerKeyMd5(sseCustomerKeyMd5); - copyPartResult2.setSSECustomerKeyMd5(sseCustomerKeyMd5); - } - CopyPartRequest copyPartRequest = new CopyPartRequest(); - copyPartRequest.setSourceSSECustomerKey(sseCustomerKey); - when(s3FileSystem.getAmazonS3Client().copyPart(any(CopyPartRequest.class))) - .thenReturn(copyPartResult1) - .thenReturn(copyPartResult2); - assertEquals( - sseCustomerKeyMd5, - 
s3FileSystem.getAmazonS3Client().copyPart(copyPartRequest).getSSECustomerKeyMd5()); - - s3FileSystem.multipartCopy(sourcePath, destinationPath, sourceObjectMetadata); - - verify(s3FileSystem.getAmazonS3Client(), times(1)) - .completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); - } - - @Test - public void deleteThousandsOfObjectsInMultipleBuckets() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3")); - - List buckets = ImmutableList.of("bucket1", "bucket2"); - List keys = new ArrayList<>(); - for (int i = 0; i < 2500; i++) { - keys.add(String.format("key-%d", i)); - } - List paths = new ArrayList<>(); - for (String bucket : buckets) { - for (String key : keys) { - paths.add(S3ResourceId.fromComponents("s3", bucket, key)); - } - } - - s3FileSystem.delete(paths); - - // Should require 6 calls to delete 2500 objects in each of 2 buckets. - verify(s3FileSystem.getAmazonS3Client(), times(6)) - .deleteObjects(any(DeleteObjectsRequest.class)); - } - - @Test - public void deleteThousandsOfObjectsInMultipleBucketsWithS3Options() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - List buckets = ImmutableList.of("bucket1", "bucket2"); - List keys = new ArrayList<>(); - for (int i = 0; i < 2500; i++) { - keys.add(String.format("key-%d", i)); - } - List paths = new ArrayList<>(); - for (String bucket : buckets) { - for (String key : keys) { - paths.add(S3ResourceId.fromComponents("s3", bucket, key)); - } - } - - s3FileSystem.delete(paths); - - // Should require 6 calls to delete 2500 objects in each of 2 buckets. 
- verify(s3FileSystem.getAmazonS3Client(), times(6)) - .deleteObjects(any(DeleteObjectsRequest.class)); - } - - @Test - public void matchNonGlob() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("mys3")); - - S3ResourceId path = S3ResourceId.fromUri("mys3://testbucket/testdirectory/filethatexists"); - long lastModifiedMillis = 1540000000000L; - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setContentEncoding("read-seek-efficient"); - s3ObjectMetadata.setLastModified(new Date(lastModifiedMillis)); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenReturn(s3ObjectMetadata); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setSizeBytes(100) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(path) - .setIsReadSeekEfficient(true) - .build()))); - } - - @Test - public void matchNonGlobWithS3Options() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - long lastModifiedMillis = 1540000000000L; - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setContentEncoding("read-seek-efficient"); - s3ObjectMetadata.setLastModified(new Date(lastModifiedMillis)); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenReturn(s3ObjectMetadata); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - 
.setSizeBytes(100) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(path) - .setIsReadSeekEfficient(true) - .build()))); - } - - @Test - public void matchNonGlobNotReadSeekEfficient() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3")); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - long lastModifiedMillis = 1540000000000L; - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setLastModified(new Date(lastModifiedMillis)); - s3ObjectMetadata.setContentEncoding("gzip"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenReturn(s3ObjectMetadata); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setSizeBytes(100) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(path) - .setIsReadSeekEfficient(false) - .build()))); - } - - @Test - public void matchNonGlobNotReadSeekEfficientWithS3Options() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - long lastModifiedMillis = 1540000000000L; - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setLastModified(new Date(lastModifiedMillis)); - s3ObjectMetadata.setContentEncoding("gzip"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenReturn(s3ObjectMetadata); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create( - ImmutableList.of( - 
MatchResult.Metadata.builder() - .setSizeBytes(100) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(path) - .setIsReadSeekEfficient(false) - .build()))); - } - - @Test - public void matchNonGlobNullContentEncoding() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3")); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - long lastModifiedMillis = 1540000000000L; - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setLastModified(new Date(lastModifiedMillis)); - s3ObjectMetadata.setContentEncoding(null); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenReturn(s3ObjectMetadata); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setSizeBytes(100) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(path) - .setIsReadSeekEfficient(true) - .build()))); - } - - @Test - public void matchNonGlobNullContentEncodingWithS3Options() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - long lastModifiedMillis = 1540000000000L; - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setLastModified(new Date(lastModifiedMillis)); - s3ObjectMetadata.setContentEncoding(null); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenReturn(s3ObjectMetadata); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create( - 
ImmutableList.of( - MatchResult.Metadata.builder() - .setSizeBytes(100) - .setLastModifiedMillis(lastModifiedMillis) - .setResourceId(path) - .setIsReadSeekEfficient(true) - .build()))); - } - - @Test - public void matchNonGlobNotFound() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("mys3")); - - S3ResourceId path = S3ResourceId.fromUri("mys3://testbucket/testdirectory/nonexistentfile"); - AmazonS3Exception exception = new AmazonS3Exception("mock exception"); - exception.setStatusCode(404); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenThrow(exception); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create(MatchResult.Status.NOT_FOUND, new FileNotFoundException())); - } - - @Test - public void matchNonGlobNotFoundWithS3Options() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/nonexistentfile"); - AmazonS3Exception exception = new AmazonS3Exception("mock exception"); - exception.setStatusCode(404); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenThrow(exception); - - MatchResult result = s3FileSystem.matchNonGlobPath(path); - assertThat( - result, - MatchResultMatcher.create(MatchResult.Status.NOT_FOUND, new FileNotFoundException())); - } - - @Test - public void matchNonGlobForbidden() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3")); - - AmazonS3Exception exception = new AmazonS3Exception("mock exception"); - exception.setStatusCode(403); - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/keyname"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - 
argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenThrow(exception); - - assertThat( - s3FileSystem.matchNonGlobPath(path), - MatchResultMatcher.create(MatchResult.Status.ERROR, new IOException(exception))); - } - - @Test - public void matchNonGlobForbiddenWithS3Options() { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - AmazonS3Exception exception = new AmazonS3Exception("mock exception"); - exception.setStatusCode(403); - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/testdirectory/keyname"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(path.getBucket(), path.getKey()))))) - .thenThrow(exception); - - assertThat( - s3FileSystem.matchNonGlobPath(path), - MatchResultMatcher.create(MatchResult.Status.ERROR, new IOException(exception))); - } - - static class ListObjectsV2RequestArgumentMatches - implements ArgumentMatcher { - private final ListObjectsV2Request expected; - - ListObjectsV2RequestArgumentMatches(ListObjectsV2Request expected) { - this.expected = checkNotNull(expected); - } - - @Override - public boolean matches(ListObjectsV2Request argument) { - if (argument instanceof ListObjectsV2Request) { - ListObjectsV2Request actual = (ListObjectsV2Request) argument; - return expected.getBucketName().equals(actual.getBucketName()) - && expected.getPrefix().equals(actual.getPrefix()) - && (expected.getContinuationToken() == null - ? 
actual.getContinuationToken() == null - : expected.getContinuationToken().equals(actual.getContinuationToken())); - } - return false; - } - } - - @Test - public void matchGlob() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("mys3")); - - S3ResourceId path = S3ResourceId.fromUri("mys3://testbucket/foo/bar*baz"); - - ListObjectsV2Request firstRequest = - new ListObjectsV2Request() - .withBucketName(path.getBucket()) - .withPrefix(path.getKeyNonWildcardPrefix()) - .withContinuationToken(null); - - // Expected to be returned; prefix and wildcard/regex match - S3ObjectSummary firstMatch = new S3ObjectSummary(); - firstMatch.setBucketName(path.getBucket()); - firstMatch.setKey("foo/bar0baz"); - firstMatch.setSize(100); - firstMatch.setLastModified(new Date(1540000000001L)); - - // Expected to not be returned; prefix matches, but substring after wildcard does not - S3ObjectSummary secondMatch = new S3ObjectSummary(); - secondMatch.setBucketName(path.getBucket()); - secondMatch.setKey("foo/bar1qux"); - secondMatch.setSize(200); - secondMatch.setLastModified(new Date(1540000000002L)); - - // Expected first request returns continuation token - ListObjectsV2Result firstResult = new ListObjectsV2Result(); - firstResult.setNextContinuationToken("token"); - firstResult.getObjectSummaries().add(firstMatch); - firstResult.getObjectSummaries().add(secondMatch); - when(s3FileSystem - .getAmazonS3Client() - .listObjectsV2(argThat(new ListObjectsV2RequestArgumentMatches(firstRequest)))) - .thenReturn(firstResult); - - // Expect second request with continuation token - ListObjectsV2Request secondRequest = - new ListObjectsV2Request() - .withBucketName(path.getBucket()) - .withPrefix(path.getKeyNonWildcardPrefix()) - .withContinuationToken("token"); - - // Expected to be returned; prefix and wildcard/regex match - S3ObjectSummary thirdMatch = new S3ObjectSummary(); - thirdMatch.setBucketName(path.getBucket()); - thirdMatch.setKey("foo/bar2baz"); - 
thirdMatch.setSize(300); - thirdMatch.setLastModified(new Date(1540000000003L)); - - // Expected second request returns third prefix match and no continuation token - ListObjectsV2Result secondResult = new ListObjectsV2Result(); - secondResult.setNextContinuationToken(null); - secondResult.getObjectSummaries().add(thirdMatch); - when(s3FileSystem - .getAmazonS3Client() - .listObjectsV2(argThat(new ListObjectsV2RequestArgumentMatches(secondRequest)))) - .thenReturn(secondResult); - - // Expect object metadata queries for content encoding - ObjectMetadata metadata = new ObjectMetadata(); - metadata.setContentEncoding(""); - when(s3FileSystem.getAmazonS3Client().getObjectMetadata(anyObject())).thenReturn(metadata); - - assertThat( - s3FileSystem.matchGlobPaths(ImmutableList.of(path)).get(0), - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(true) - .setResourceId( - S3ResourceId.fromComponents( - "mys3", firstMatch.getBucketName(), firstMatch.getKey())) - .setSizeBytes(firstMatch.getSize()) - .setLastModifiedMillis(firstMatch.getLastModified().getTime()) - .build(), - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(true) - .setResourceId( - S3ResourceId.fromComponents( - "mys3", thirdMatch.getBucketName(), thirdMatch.getKey())) - .setSizeBytes(thirdMatch.getSize()) - .setLastModifiedMillis(thirdMatch.getLastModified().getTime()) - .build()))); - } - - @Test - public void matchGlobWithS3Options() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/foo/bar*baz"); - - ListObjectsV2Request firstRequest = - new ListObjectsV2Request() - .withBucketName(path.getBucket()) - .withPrefix(path.getKeyNonWildcardPrefix()) - .withContinuationToken(null); - - // Expected to be returned; prefix and wildcard/regex match - S3ObjectSummary firstMatch = new S3ObjectSummary(); - firstMatch.setBucketName(path.getBucket()); - 
firstMatch.setKey("foo/bar0baz"); - firstMatch.setSize(100); - firstMatch.setLastModified(new Date(1540000000001L)); - - // Expected to not be returned; prefix matches, but substring after wildcard does not - S3ObjectSummary secondMatch = new S3ObjectSummary(); - secondMatch.setBucketName(path.getBucket()); - secondMatch.setKey("foo/bar1qux"); - secondMatch.setSize(200); - secondMatch.setLastModified(new Date(1540000000002L)); - - // Expected first request returns continuation token - ListObjectsV2Result firstResult = new ListObjectsV2Result(); - firstResult.setNextContinuationToken("token"); - firstResult.getObjectSummaries().add(firstMatch); - firstResult.getObjectSummaries().add(secondMatch); - when(s3FileSystem - .getAmazonS3Client() - .listObjectsV2(argThat(new ListObjectsV2RequestArgumentMatches(firstRequest)))) - .thenReturn(firstResult); - - // Expect second request with continuation token - ListObjectsV2Request secondRequest = - new ListObjectsV2Request() - .withBucketName(path.getBucket()) - .withPrefix(path.getKeyNonWildcardPrefix()) - .withContinuationToken("token"); - - // Expected to be returned; prefix and wildcard/regex match - S3ObjectSummary thirdMatch = new S3ObjectSummary(); - thirdMatch.setBucketName(path.getBucket()); - thirdMatch.setKey("foo/bar2baz"); - thirdMatch.setSize(300); - thirdMatch.setLastModified(new Date(1540000000003L)); - - // Expected second request returns third prefix match and no continuation token - ListObjectsV2Result secondResult = new ListObjectsV2Result(); - secondResult.setNextContinuationToken(null); - secondResult.getObjectSummaries().add(thirdMatch); - when(s3FileSystem - .getAmazonS3Client() - .listObjectsV2(argThat(new ListObjectsV2RequestArgumentMatches(secondRequest)))) - .thenReturn(secondResult); - - // Expect object metadata queries for content encoding - ObjectMetadata metadata = new ObjectMetadata(); - metadata.setContentEncoding(""); - 
when(s3FileSystem.getAmazonS3Client().getObjectMetadata(anyObject())).thenReturn(metadata); - - assertThat( - s3FileSystem.matchGlobPaths(ImmutableList.of(path)).get(0), - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(true) - .setResourceId( - S3ResourceId.fromComponents( - "s3", firstMatch.getBucketName(), firstMatch.getKey())) - .setSizeBytes(firstMatch.getSize()) - .setLastModifiedMillis(firstMatch.getLastModified().getTime()) - .build(), - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(true) - .setResourceId( - S3ResourceId.fromComponents( - "s3", thirdMatch.getBucketName(), thirdMatch.getKey())) - .setSizeBytes(thirdMatch.getSize()) - .setLastModifiedMillis(thirdMatch.getLastModified().getTime()) - .build()))); - } - - @Test - public void matchGlobWithSlashes() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3")); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/foo/bar\\baz*"); - - ListObjectsV2Request request = - new ListObjectsV2Request() - .withBucketName(path.getBucket()) - .withPrefix(path.getKeyNonWildcardPrefix()) - .withContinuationToken(null); - - // Expected to be returned; prefix and wildcard/regex match - S3ObjectSummary firstMatch = new S3ObjectSummary(); - firstMatch.setBucketName(path.getBucket()); - firstMatch.setKey("foo/bar\\baz0"); - firstMatch.setSize(100); - firstMatch.setLastModified(new Date(1540000000001L)); - - // Expected to not be returned; prefix matches, but substring after wildcard does not - S3ObjectSummary secondMatch = new S3ObjectSummary(); - secondMatch.setBucketName(path.getBucket()); - secondMatch.setKey("foo/bar/baz1"); - secondMatch.setSize(200); - secondMatch.setLastModified(new Date(1540000000002L)); - - // Expected first request returns continuation token - ListObjectsV2Result result = new ListObjectsV2Result(); - result.getObjectSummaries().add(firstMatch); - 
result.getObjectSummaries().add(secondMatch); - when(s3FileSystem - .getAmazonS3Client() - .listObjectsV2(argThat(new ListObjectsV2RequestArgumentMatches(request)))) - .thenReturn(result); - - // Expect object metadata queries for content encoding - ObjectMetadata metadata = new ObjectMetadata(); - metadata.setContentEncoding(""); - when(s3FileSystem.getAmazonS3Client().getObjectMetadata(anyObject())).thenReturn(metadata); - - assertThat( - s3FileSystem.matchGlobPaths(ImmutableList.of(path)).get(0), - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(true) - .setResourceId( - S3ResourceId.fromComponents( - "s3", firstMatch.getBucketName(), firstMatch.getKey())) - .setSizeBytes(firstMatch.getSize()) - .setLastModifiedMillis(firstMatch.getLastModified().getTime()) - .build()))); - } - - @Test - public void matchGlobWithSlashesWithS3Options() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/foo/bar\\baz*"); - - ListObjectsV2Request request = - new ListObjectsV2Request() - .withBucketName(path.getBucket()) - .withPrefix(path.getKeyNonWildcardPrefix()) - .withContinuationToken(null); - - // Expected to be returned; prefix and wildcard/regex match - S3ObjectSummary firstMatch = new S3ObjectSummary(); - firstMatch.setBucketName(path.getBucket()); - firstMatch.setKey("foo/bar\\baz0"); - firstMatch.setSize(100); - firstMatch.setLastModified(new Date(1540000000001L)); - - // Expected to not be returned; prefix matches, but substring after wildcard does not - S3ObjectSummary secondMatch = new S3ObjectSummary(); - secondMatch.setBucketName(path.getBucket()); - secondMatch.setKey("foo/bar/baz1"); - secondMatch.setSize(200); - secondMatch.setLastModified(new Date(1540000000002L)); - - // Expected first request returns continuation token - ListObjectsV2Result result = new ListObjectsV2Result(); - 
result.getObjectSummaries().add(firstMatch); - result.getObjectSummaries().add(secondMatch); - when(s3FileSystem - .getAmazonS3Client() - .listObjectsV2(argThat(new ListObjectsV2RequestArgumentMatches(request)))) - .thenReturn(result); - - // Expect object metadata queries for content encoding - ObjectMetadata metadata = new ObjectMetadata(); - metadata.setContentEncoding(""); - when(s3FileSystem.getAmazonS3Client().getObjectMetadata(anyObject())).thenReturn(metadata); - - assertThat( - s3FileSystem.matchGlobPaths(ImmutableList.of(path)).get(0), - MatchResultMatcher.create( - ImmutableList.of( - MatchResult.Metadata.builder() - .setIsReadSeekEfficient(true) - .setResourceId( - S3ResourceId.fromComponents( - "s3", firstMatch.getBucketName(), firstMatch.getKey())) - .setSizeBytes(firstMatch.getSize()) - .setLastModifiedMillis(firstMatch.getLastModified().getTime()) - .build()))); - } - - @Test - public void matchVariousInvokeThreadPool() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3")); - - AmazonS3Exception notFoundException = new AmazonS3Exception("mock exception"); - notFoundException.setStatusCode(404); - S3ResourceId pathNotExist = - S3ResourceId.fromUri("s3://testbucket/testdirectory/nonexistentfile"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest( - pathNotExist.getBucket(), pathNotExist.getKey()))))) - .thenThrow(notFoundException); - - AmazonS3Exception forbiddenException = new AmazonS3Exception("mock exception"); - forbiddenException.setStatusCode(403); - S3ResourceId pathForbidden = - S3ResourceId.fromUri("s3://testbucket/testdirectory/forbiddenfile"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest( - pathForbidden.getBucket(), pathForbidden.getKey()))))) - .thenThrow(forbiddenException); - - S3ResourceId pathExist = 
S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setLastModified(new Date(1540000000000L)); - s3ObjectMetadata.setContentEncoding("not-gzip"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(pathExist.getBucket(), pathExist.getKey()))))) - .thenReturn(s3ObjectMetadata); - - S3ResourceId pathGlob = S3ResourceId.fromUri("s3://testbucket/path/part*"); - - S3ObjectSummary foundListObject = new S3ObjectSummary(); - foundListObject.setBucketName(pathGlob.getBucket()); - foundListObject.setKey("path/part-0"); - foundListObject.setSize(200); - foundListObject.setLastModified(new Date(1541000000000L)); - - ListObjectsV2Result listObjectsResult = new ListObjectsV2Result(); - listObjectsResult.setNextContinuationToken(null); - listObjectsResult.getObjectSummaries().add(foundListObject); - when(s3FileSystem.getAmazonS3Client().listObjectsV2(notNull(ListObjectsV2Request.class))) - .thenReturn(listObjectsResult); - - ObjectMetadata metadata = new ObjectMetadata(); - metadata.setContentEncoding(""); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(pathGlob.getBucket(), "path/part-0"))))) - .thenReturn(metadata); - - assertThat( - s3FileSystem.match( - ImmutableList.of( - pathNotExist.toString(), - pathForbidden.toString(), - pathExist.toString(), - pathGlob.toString())), - contains( - MatchResultMatcher.create(MatchResult.Status.NOT_FOUND, new FileNotFoundException()), - MatchResultMatcher.create( - MatchResult.Status.ERROR, new IOException(forbiddenException)), - MatchResultMatcher.create(100, 1540000000000L, pathExist, true), - MatchResultMatcher.create( - 200, - 1541000000000L, - S3ResourceId.fromComponents("s3", pathGlob.getBucket(), 
foundListObject.getKey()), - true))); - } - - @Test - public void matchVariousInvokeThreadPoolWithS3Options() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options()); - - AmazonS3Exception notFoundException = new AmazonS3Exception("mock exception"); - notFoundException.setStatusCode(404); - S3ResourceId pathNotExist = - S3ResourceId.fromUri("s3://testbucket/testdirectory/nonexistentfile"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest( - pathNotExist.getBucket(), pathNotExist.getKey()))))) - .thenThrow(notFoundException); - - AmazonS3Exception forbiddenException = new AmazonS3Exception("mock exception"); - forbiddenException.setStatusCode(403); - S3ResourceId pathForbidden = - S3ResourceId.fromUri("s3://testbucket/testdirectory/forbiddenfile"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest( - pathForbidden.getBucket(), pathForbidden.getKey()))))) - .thenThrow(forbiddenException); - - S3ResourceId pathExist = S3ResourceId.fromUri("s3://testbucket/testdirectory/filethatexists"); - ObjectMetadata s3ObjectMetadata = new ObjectMetadata(); - s3ObjectMetadata.setContentLength(100); - s3ObjectMetadata.setLastModified(new Date(1540000000000L)); - s3ObjectMetadata.setContentEncoding("not-gzip"); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(pathExist.getBucket(), pathExist.getKey()))))) - .thenReturn(s3ObjectMetadata); - - S3ResourceId pathGlob = S3ResourceId.fromUri("s3://testbucket/path/part*"); - - S3ObjectSummary foundListObject = new S3ObjectSummary(); - foundListObject.setBucketName(pathGlob.getBucket()); - foundListObject.setKey("path/part-0"); - foundListObject.setSize(200); - foundListObject.setLastModified(new Date(1541000000000L)); - - 
ListObjectsV2Result listObjectsResult = new ListObjectsV2Result(); - listObjectsResult.setNextContinuationToken(null); - listObjectsResult.getObjectSummaries().add(foundListObject); - when(s3FileSystem.getAmazonS3Client().listObjectsV2(notNull(ListObjectsV2Request.class))) - .thenReturn(listObjectsResult); - - ObjectMetadata metadata = new ObjectMetadata(); - metadata.setContentEncoding(""); - when(s3FileSystem - .getAmazonS3Client() - .getObjectMetadata( - argThat( - new GetObjectMetadataRequestMatcher( - new GetObjectMetadataRequest(pathGlob.getBucket(), "path/part-0"))))) - .thenReturn(metadata); - - assertThat( - s3FileSystem.match( - ImmutableList.of( - pathNotExist.toString(), - pathForbidden.toString(), - pathExist.toString(), - pathGlob.toString())), - contains( - MatchResultMatcher.create(MatchResult.Status.NOT_FOUND, new FileNotFoundException()), - MatchResultMatcher.create( - MatchResult.Status.ERROR, new IOException(forbiddenException)), - MatchResultMatcher.create(100, 1540000000000L, pathExist, true), - MatchResultMatcher.create( - 200, - 1541000000000L, - S3ResourceId.fromComponents("s3", pathGlob.getBucket(), foundListObject.getKey()), - true))); - } - - @Test - public void testWriteAndRead() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("s3"), client); - - client.createBucket("testbucket"); - - byte[] writtenArray = new byte[] {0}; - ByteBuffer bb = ByteBuffer.allocate(writtenArray.length); - bb.put(writtenArray); - - // First create an object and write data to it - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/foo/bar.txt"); - WritableByteChannel writableByteChannel = - s3FileSystem.create( - path, - CreateOptions.StandardCreateOptions.builder().setMimeType("application/text").build()); - writableByteChannel.write(bb); - writableByteChannel.close(); - - // Now read the same object - ByteBuffer bb2 = ByteBuffer.allocate(writtenArray.length); - ReadableByteChannel open = s3FileSystem.open(path); - 
open.read(bb2); - - // And compare the content with the one that was written - byte[] readArray = bb2.array(); - assertArrayEquals(readArray, writtenArray); - open.close(); - } - - @Test - public void testWriteAndReadWithS3Options() throws IOException { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Options(), client); - - client.createBucket("testbucket"); - - byte[] writtenArray = new byte[] {0}; - ByteBuffer bb = ByteBuffer.allocate(writtenArray.length); - bb.put(writtenArray); - - // First create an object and write data to it - S3ResourceId path = S3ResourceId.fromUri("s3://testbucket/foo/bar.txt"); - WritableByteChannel writableByteChannel = - s3FileSystem.create( - path, - CreateOptions.StandardCreateOptions.builder().setMimeType("application/text").build()); - writableByteChannel.write(bb); - writableByteChannel.close(); - - // Now read the same object - ByteBuffer bb2 = ByteBuffer.allocate(writtenArray.length); - ReadableByteChannel open = s3FileSystem.open(path); - open.read(bb2); - - // And compare the content with the one that was written - byte[] readArray = bb2.array(); - assertArrayEquals(readArray, writtenArray); - open.close(); - } - - @Test - public void testReportLineageOnBucket() { - verifyLineage("s3://testbucket", ImmutableList.of("testbucket")); - verifyLineage("s3://testbucket/", ImmutableList.of("testbucket")); - verifyLineage("s3://testbucket/foo/bar.txt", ImmutableList.of("testbucket", "foo/bar.txt")); - } - - private void verifyLineage(String uri, List expected) { - S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("mys3"), client); - S3ResourceId path = S3ResourceId.fromUri(uri); - Lineage mockLineage = mock(Lineage.class); - s3FileSystem.reportLineage(path, mockLineage); - verify(mockLineage, times(1)).add("s3", expected); - } - - /** A mockito argument matcher to implement equality on GetObjectMetadataRequest. 
*/ - private static class GetObjectMetadataRequestMatcher - implements ArgumentMatcher { - private final GetObjectMetadataRequest expected; - - GetObjectMetadataRequestMatcher(GetObjectMetadataRequest expected) { - this.expected = expected; - } - - @Override - public boolean matches(GetObjectMetadataRequest obj) { - if (!(obj instanceof GetObjectMetadataRequest)) { - return false; - } - GetObjectMetadataRequest actual = (GetObjectMetadataRequest) obj; - return actual.getBucketName().equals(expected.getBucketName()) - && actual.getKey().equals(expected.getKey()); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3ResourceIdTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3ResourceIdTest.java deleted file mode 100644 index dd759cb63dbd..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3ResourceIdTest.java +++ /dev/null @@ -1,348 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY; -import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.util.Arrays; -import java.util.Base64; -import java.util.Calendar; -import java.util.Date; -import java.util.List; -import org.apache.beam.sdk.io.FileSystems; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; -import org.apache.beam.sdk.io.fs.ResourceId; -import org.apache.beam.sdk.io.fs.ResourceIdTester; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link S3ResourceId}. */ -@RunWith(JUnit4.class) -public class S3ResourceIdTest { - - @Rule public ExpectedException thrown = ExpectedException.none(); - - static final class TestCase { - - final String baseUri; - final String relativePath; - final StandardResolveOptions resolveOptions; - final String expectedResult; - - TestCase( - String baseUri, - String relativePath, - StandardResolveOptions resolveOptions, - String expectedResult) { - this.baseUri = baseUri; - this.relativePath = relativePath; - this.resolveOptions = resolveOptions; - this.expectedResult = expectedResult; - } - } - - // Each test case is an expected URL, then the components used to build it. - // Empty components result in a double slash. 
- private static final List PATH_TEST_CASES = - Arrays.asList( - new TestCase("s3://bucket/", "", RESOLVE_DIRECTORY, "s3://bucket/"), - new TestCase("s3://bucket", "", RESOLVE_DIRECTORY, "s3://bucket/"), - new TestCase("s3://bucket", "path/to/dir", RESOLVE_DIRECTORY, "s3://bucket/path/to/dir/"), - new TestCase("s3://bucket", "path/to/object", RESOLVE_FILE, "s3://bucket/path/to/object"), - new TestCase( - "s3://bucket/path/to/dir/", "..", RESOLVE_DIRECTORY, "s3://bucket/path/to/")); - - private S3ResourceId deserializeFromB64(String base64) throws Exception { - ByteArrayInputStream b = new ByteArrayInputStream(Base64.getDecoder().decode(base64)); - try (ObjectInputStream s = new ObjectInputStream(b)) { - return (S3ResourceId) s.readObject(); - } - } - - private String serializeToB64(S3ResourceId r) throws Exception { - ByteArrayOutputStream b = new ByteArrayOutputStream(); - try (ObjectOutputStream s = new ObjectOutputStream(b)) { - s.writeObject(r); - } - return Base64.getEncoder().encodeToString(b.toByteArray()); - } - - @Test - public void testSerialization() throws Exception { - String r1Serialized = - "rO0ABXNyACpvcmcuYXBhY2hlLmJlYW0uc2RrLmlvLmF3cy5zMy5TM1Jlc291cmNlSWSN8nM8V4cVFwIABEwABmJ1Y2tldHQAEkxqYXZhL2xhbmcvU3RyaW5nO0wAA2tleXEAfgABTAAMbGFzdE1vZGlmaWVkdAAQTGphdmEvdXRpbC9EYXRlO0wABHNpemV0ABBMamF2YS9sYW5nL0xvbmc7eHB0AAZidWNrZXR0AAYvYS9iL2NwcA=="; - String r2Serialized = - "rO0ABXNyACpvcmcuYXBhY2hlLmJlYW0uc2RrLmlvLmF3cy5zMy5TM1Jlc291cmNlSWSN8nM8V4cVFwIABEwABmJ1Y2tldHQAEkxqYXZhL2xhbmcvU3RyaW5nO0wAA2tleXEAfgABTAAMbGFzdE1vZGlmaWVkdAAQTGphdmEvdXRpbC9EYXRlO0wABHNpemV0ABBMamF2YS9sYW5nL0xvbmc7eHB0AAxvdGhlci1idWNrZXR0AAYveC95L3pwc3IADmphdmEubGFuZy5Mb25nO4vkkMyPI98CAAFKAAV2YWx1ZXhyABBqYXZhLmxhbmcuTnVtYmVyhqyVHQuU4IsCAAB4cAAAAAAAAAB7"; - String r3Serialized = - 
"rO0ABXNyACpvcmcuYXBhY2hlLmJlYW0uc2RrLmlvLmF3cy5zMy5TM1Jlc291cmNlSWSN8nM8V4cVFwIABEwABmJ1Y2tldHQAEkxqYXZhL2xhbmcvU3RyaW5nO0wAA2tleXEAfgABTAAMbGFzdE1vZGlmaWVkdAAQTGphdmEvdXRpbC9EYXRlO0wABHNpemV0ABBMamF2YS9sYW5nL0xvbmc7eHB0AAx0aGlyZC1idWNrZXR0AAkvZm9vL2Jhci9zcgAOamF2YS51dGlsLkRhdGVoaoEBS1l0GQMAAHhwdwgAADgCgmXOAHhw"; - String r4Serialized = - "rO0ABXNyACpvcmcuYXBhY2hlLmJlYW0uc2RrLmlvLmF3cy5zMy5TM1Jlc291cmNlSWSN8nM8V4cVFwIABEwABmJ1Y2tldHQAEkxqYXZhL2xhbmcvU3RyaW5nO0wAA2tleXEAfgABTAAMbGFzdE1vZGlmaWVkdAAQTGphdmEvdXRpbC9EYXRlO0wABHNpemV0ABBMamF2YS9sYW5nL0xvbmc7eHB0AApiYXotYnVja2V0dAAGL2EvYi9jc3IADmphdmEudXRpbC5EYXRlaGqBAUtZdBkDAAB4cHcIAAA33gSV5gB4c3IADmphdmEubGFuZy5Mb25nO4vkkMyPI98CAAFKAAV2YWx1ZXhyABBqYXZhLmxhbmcuTnVtYmVyhqyVHQuU4IsCAAB4cAAAAAAAAAAq"; - - S3ResourceId r1 = S3ResourceId.fromComponents("s3", "bucket", "a/b/c"); - S3ResourceId r2 = S3ResourceId.fromComponents("s3", "other-bucket", "x/y/z").withSize(123); - S3ResourceId r3 = - S3ResourceId.fromComponents("s3", "third-bucket", "foo/bar/") - .withLastModified(new Date(121, Calendar.JULY, 3)); - S3ResourceId r4 = - S3ResourceId.fromComponents("s3", "baz-bucket", "a/b/c") - .withSize(42) - .withLastModified(new Date(116, Calendar.JULY, 15)); - S3ResourceId r5 = S3ResourceId.fromComponents("other-scheme", "bucket", "a/b/c"); - S3ResourceId r6 = - S3ResourceId.fromComponents("other-scheme", "baz-bucket", "foo/bar/") - .withSize(42) - .withLastModified(new Date(116, Calendar.JULY, 5)); - - // S3ResourceIds serialized by previous versions should still deserialize. - assertEquals(r1, deserializeFromB64(r1Serialized)); - assertEquals(r2, deserializeFromB64(r2Serialized)); - assertEquals(r3, deserializeFromB64(r3Serialized)); - assertEquals(r4, deserializeFromB64(r4Serialized)); - - // Current resource IDs should round-trip properly through serialization. 
- assertEquals(r1, deserializeFromB64(serializeToB64(r1))); - assertEquals(r2, deserializeFromB64(serializeToB64(r2))); - assertEquals(r3, deserializeFromB64(serializeToB64(r3))); - assertEquals(r4, deserializeFromB64(serializeToB64(r4))); - assertEquals(r5, deserializeFromB64(serializeToB64(r5))); - assertEquals(r6, deserializeFromB64(serializeToB64(r6))); - } - - @Test - public void testResolve() { - for (TestCase testCase : PATH_TEST_CASES) { - ResourceId resourceId = S3ResourceId.fromUri(testCase.baseUri); - ResourceId resolved = resourceId.resolve(testCase.relativePath, testCase.resolveOptions); - assertEquals(testCase.expectedResult, resolved.toString()); - } - - // Tests for common s3 paths. - assertEquals( - S3ResourceId.fromUri("s3://bucket/tmp/aa"), - S3ResourceId.fromUri("s3://bucket/tmp/").resolve("aa", RESOLVE_FILE)); - assertEquals( - S3ResourceId.fromUri("s3://bucket/tmp/aa/bb/cc/"), - S3ResourceId.fromUri("s3://bucket/tmp/") - .resolve("aa", RESOLVE_DIRECTORY) - .resolve("bb", RESOLVE_DIRECTORY) - .resolve("cc", RESOLVE_DIRECTORY)); - - // Tests absolute path. - assertEquals( - S3ResourceId.fromUri("s3://bucket/tmp/aa"), - S3ResourceId.fromUri("s3://bucket/tmp/bb/").resolve("s3://bucket/tmp/aa", RESOLVE_FILE)); - - // Tests bucket with no ending '/'. 
- assertEquals( - S3ResourceId.fromUri("s3://my-bucket/tmp"), - S3ResourceId.fromUri("s3://my-bucket").resolve("tmp", RESOLVE_FILE)); - - // Tests path with unicode - assertEquals( - S3ResourceId.fromUri("s3://bucket/输出 目录/输出 文件01.txt"), - S3ResourceId.fromUri("s3://bucket/输出 目录/").resolve("输出 文件01.txt", RESOLVE_FILE)); - } - - @Test - public void testResolveInvalidInputs() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Cannot resolve a file with a directory path: [tmp/]"); - S3ResourceId.fromUri("s3://my_bucket/").resolve("tmp/", RESOLVE_FILE); - } - - @Test - public void testResolveInvalidNotDirectory() { - ResourceId tmpDir = S3ResourceId.fromUri("s3://my_bucket/").resolve("tmp dir", RESOLVE_FILE); - - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "Expected this resource to be a directory, but was [s3://my_bucket/tmp dir]"); - tmpDir.resolve("aa", RESOLVE_FILE); - } - - @Test - public void testS3ResolveWithFileBase() { - ResourceId resourceId = S3ResourceId.fromUri("s3://bucket/path/to/file"); - thrown.expect(IllegalStateException.class); - resourceId.resolve("child-path", RESOLVE_DIRECTORY); // resource is not a directory - } - - @Test - public void testResolveParentToFile() { - ResourceId resourceId = S3ResourceId.fromUri("s3://bucket/path/to/dir/"); - thrown.expect(IllegalArgumentException.class); - resourceId.resolve("..", RESOLVE_FILE); // '..' only resolves as dir, not as file - } - - @Test - public void testGetCurrentDirectory() { - // Tests s3 paths. - assertEquals( - S3ResourceId.fromUri("s3://my_bucket/tmp dir/"), - S3ResourceId.fromUri("s3://my_bucket/tmp dir/").getCurrentDirectory()); - - // Tests path with unicode. - assertEquals( - S3ResourceId.fromUri("s3://my_bucket/输出 目录/"), - S3ResourceId.fromUri("s3://my_bucket/输出 目录/文件01.txt").getCurrentDirectory()); - - // Tests bucket with no ending '/'. 
- assertEquals( - S3ResourceId.fromUri("s3://my_bucket/"), - S3ResourceId.fromUri("s3://my_bucket").getCurrentDirectory()); - assertEquals( - S3ResourceId.fromUri("s3://my_bucket/"), - S3ResourceId.fromUri("s3://my_bucket/not-directory").getCurrentDirectory()); - } - - @Test - public void testIsDirectory() { - assertTrue(S3ResourceId.fromUri("s3://my_bucket/tmp dir/").isDirectory()); - assertTrue(S3ResourceId.fromUri("s3://my_bucket/").isDirectory()); - assertTrue(S3ResourceId.fromUri("s3://my_bucket").isDirectory()); - assertFalse(S3ResourceId.fromUri("s3://my_bucket/file").isDirectory()); - } - - @Test - public void testInvalidPathNoBucket() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid S3 URI: [s3://]"); - S3ResourceId.fromUri("s3://"); - } - - @Test - public void testInvalidPathNoBucketAndSlash() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid S3 URI: [s3:///]"); - S3ResourceId.fromUri("s3:///"); - } - - @Test - public void testGetScheme() { - // Tests s3 paths. - assertEquals("s3", S3ResourceId.fromUri("s3://my_bucket/tmp dir/").getScheme()); - - // Tests bucket with no ending '/'. 
- assertEquals("s3", S3ResourceId.fromUri("s3://my_bucket").getScheme()); - } - - @Test - public void testGetFilename() { - assertNull(S3ResourceId.fromUri("s3://my_bucket/").getFilename()); - assertEquals("abc", S3ResourceId.fromUri("s3://my_bucket/abc").getFilename()); - assertEquals("abc", S3ResourceId.fromUri("s3://my_bucket/abc/").getFilename()); - assertEquals("def", S3ResourceId.fromUri("s3://my_bucket/abc/def").getFilename()); - assertEquals("def", S3ResourceId.fromUri("s3://my_bucket/abc/def/").getFilename()); - assertEquals("xyz.txt", S3ResourceId.fromUri("s3://my_bucket/abc/xyz.txt").getFilename()); - } - - @Test - public void testParentRelationship() { - S3ResourceId path = S3ResourceId.fromUri("s3://bucket/dir/subdir/object"); - assertEquals("bucket", path.getBucket()); - assertEquals("dir/subdir/object", path.getKey()); - - // s3://bucket/dir/ - path = S3ResourceId.fromUri("s3://bucket/dir/subdir/"); - S3ResourceId parent = (S3ResourceId) path.resolve("..", RESOLVE_DIRECTORY); - assertEquals("bucket", parent.getBucket()); - assertEquals("dir/", parent.getKey()); - assertNotEquals(path, parent); - assertTrue(path.getKey().startsWith(parent.getKey())); - assertFalse(parent.getKey().startsWith(path.getKey())); - - // s3://bucket/ - S3ResourceId grandParent = (S3ResourceId) parent.resolve("..", RESOLVE_DIRECTORY); - assertEquals("bucket", grandParent.getBucket()); - assertEquals("", grandParent.getKey()); - } - - @Test - public void testBucketParsing() { - S3ResourceId path = S3ResourceId.fromUri("s3://bucket"); - S3ResourceId path2 = S3ResourceId.fromUri("s3://bucket/"); - - assertEquals(path, path2); - assertEquals(path.toString(), path2.toString()); - } - - @Test - public void testS3ResourceIdToString() { - String filename = "s3://some-bucket/some/file.txt"; - S3ResourceId path = S3ResourceId.fromUri(filename); - assertEquals(filename, path.toString()); - - filename = "s3://some-bucket/some/"; - path = S3ResourceId.fromUri(filename); - 
assertEquals(filename, path.toString()); - - filename = "s3://some-bucket/"; - path = S3ResourceId.fromUri(filename); - assertEquals(filename, path.toString()); - } - - @Test - public void testEquals() { - S3ResourceId a = S3ResourceId.fromComponents("s3", "bucket", "a/b/c"); - S3ResourceId b = S3ResourceId.fromComponents("s3", "bucket", "a/b/c"); - assertEquals(a, b); - - b = S3ResourceId.fromComponents("s3", a.getBucket(), "a/b/c/"); - assertNotEquals(a, b); - - b = S3ResourceId.fromComponents("s3", a.getBucket(), "x/y/z"); - assertNotEquals(a, b); - - b = S3ResourceId.fromComponents("s3", "other-bucket", a.getKey()); - assertNotEquals(a, b); - assertNotEquals(b, a); - - b = S3ResourceId.fromComponents("other", "bucket", "a/b/c"); - assertNotEquals(a, b); - assertNotEquals(b, a); - } - - @Test - public void testInvalidBucket() { - thrown.expect(IllegalArgumentException.class); - S3ResourceId.fromComponents("s3", "invalid/", ""); - } - - @Test - public void testResourceIdTester() { - S3Options options = PipelineOptionsFactory.create().as(S3Options.class); - options.setAwsRegion("us-west-1"); - FileSystems.setDefaultPipelineOptions(options); - ResourceIdTester.runResourceIdBattery(S3ResourceId.fromUri("s3://bucket/foo/")); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3TestUtils.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3TestUtils.java deleted file mode 100644 index 3df2f10f9c82..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3TestUtils.java +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import com.amazonaws.services.s3.model.SSECustomerKey; -import com.amazonaws.util.Base64; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.commons.codec.digest.DigestUtils; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.mockito.Mockito; - -/** Utils to test S3 filesystem. 
*/ -class S3TestUtils { - private static S3FileSystemConfiguration.Builder configBuilder(String scheme) { - S3Options options = PipelineOptionsFactory.as(S3Options.class); - options.setAwsRegion("us-west-1"); - options.setS3UploadBufferSizeBytes(5_242_880); - return S3FileSystemConfiguration.fromS3Options(options).setScheme(scheme); - } - - static S3FileSystemConfiguration s3Config(String scheme) { - return configBuilder(scheme).build(); - } - - static S3Options s3Options() { - S3Options options = PipelineOptionsFactory.as(S3Options.class); - options.setAwsRegion("us-west-1"); - options.setS3UploadBufferSizeBytes(5_242_880); - return options; - } - - static S3Options s3OptionsWithCustomEndpointAndPathStyleAccessEnabled() { - S3Options options = PipelineOptionsFactory.as(S3Options.class); - options.setAwsServiceEndpoint("https://s3.custom.dns"); - options.setAwsRegion("no-matter"); - options.setS3UploadBufferSizeBytes(5_242_880); - options.setS3ClientFactoryClass(PathStyleAccessS3ClientBuilderFactory.class); - return options; - } - - static S3FileSystemConfiguration s3ConfigWithCustomEndpointAndPathStyleAccessEnabled( - String scheme) { - return S3FileSystemConfiguration.fromS3Options( - s3OptionsWithCustomEndpointAndPathStyleAccessEnabled()) - .setScheme(scheme) - .build(); - } - - static S3FileSystemConfiguration s3ConfigWithSSEAlgorithm(String scheme) { - return configBuilder(scheme) - .setSSEAlgorithm(ObjectMetadata.AES_256_SERVER_SIDE_ENCRYPTION) - .build(); - } - - static S3Options s3OptionsWithSSEAlgorithm() { - S3Options options = s3Options(); - options.setSSEAlgorithm(ObjectMetadata.AES_256_SERVER_SIDE_ENCRYPTION); - return options; - } - - static S3FileSystemConfiguration s3ConfigWithSSECustomerKey(String scheme) { - return configBuilder(scheme) - .setSSECustomerKey(new SSECustomerKey("86glyTlCNZgccSxW8JxMa6ZdjdK3N141glAysPUZ3AA=")) - .build(); - } - - static S3Options s3OptionsWithSSECustomerKey() { - S3Options options = s3Options(); - 
options.setSSECustomerKey(new SSECustomerKey("86glyTlCNZgccSxW8JxMa6ZdjdK3N141glAysPUZ3AA=")); - return options; - } - - static S3FileSystemConfiguration s3ConfigWithSSEAwsKeyManagementParams(String scheme) { - String awsKmsKeyId = - "arn:aws:kms:eu-west-1:123456789012:key/dc123456-7890-ABCD-EF01-234567890ABC"; - SSEAwsKeyManagementParams sseAwsKeyManagementParams = - new SSEAwsKeyManagementParams(awsKmsKeyId); - return configBuilder(scheme) - .setSSEAwsKeyManagementParams(sseAwsKeyManagementParams) - .setBucketKeyEnabled(true) - .build(); - } - - static S3Options s3OptionsWithSSEAwsKeyManagementParams() { - S3Options options = s3Options(); - String awsKmsKeyId = - "arn:aws:kms:eu-west-1:123456789012:key/dc123456-7890-ABCD-EF01-234567890ABC"; - SSEAwsKeyManagementParams sseAwsKeyManagementParams = - new SSEAwsKeyManagementParams(awsKmsKeyId); - options.setSSEAwsKeyManagementParams(sseAwsKeyManagementParams); - options.setBucketKeyEnabled(true); - return options; - } - - static S3FileSystemConfiguration s3ConfigWithMultipleSSEOptions(String scheme) { - return s3ConfigWithSSEAwsKeyManagementParams(scheme) - .toBuilder() - .setSSECustomerKey(new SSECustomerKey("86glyTlCNZgccSxW8JxMa6ZdjdK3N141glAysPUZ3AA=")) - .build(); - } - - static S3Options s3OptionsWithMultipleSSEOptions() { - S3Options options = s3OptionsWithSSEAwsKeyManagementParams(); - options.setSSECustomerKey(new SSECustomerKey("86glyTlCNZgccSxW8JxMa6ZdjdK3N141glAysPUZ3AA=")); - return options; - } - - static S3FileSystem buildMockedS3FileSystem(S3FileSystemConfiguration config) { - return buildMockedS3FileSystem(config, Mockito.mock(AmazonS3.class)); - } - - static S3FileSystem buildMockedS3FileSystem(S3Options options) { - return buildMockedS3FileSystem(options, Mockito.mock(AmazonS3.class)); - } - - static S3FileSystem buildMockedS3FileSystem(S3FileSystemConfiguration config, AmazonS3 client) { - S3FileSystem s3FileSystem = new S3FileSystem(config); - s3FileSystem.setAmazonS3Client(client); - return 
s3FileSystem; - } - - static S3FileSystem buildMockedS3FileSystem(S3Options options, AmazonS3 client) { - S3FileSystem s3FileSystem = new S3FileSystem(options); - s3FileSystem.setAmazonS3Client(client); - return s3FileSystem; - } - - static @Nullable String toMd5(SSECustomerKey key) { - if (key != null && key.getKey() != null) { - return Base64.encodeAsString(DigestUtils.md5(Base64.decode(key.getKey()))); - } - return null; - } - - static @Nullable String getSSECustomerKeyMd5(S3FileSystemConfiguration config) { - return toMd5(config.getSSECustomerKey()); - } - - static @Nullable String getSSECustomerKeyMd5(S3Options options) { - return toMd5(options.getSSECustomerKey()); - } - - private static class PathStyleAccessS3ClientBuilderFactory extends DefaultS3ClientBuilderFactory { - @Override - public AmazonS3ClientBuilder createBuilder(S3Options s3Options) { - return super.createBuilder(s3Options).withPathStyleAccessEnabled(true); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3WritableByteChannelTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3WritableByteChannelTest.java deleted file mode 100644 index cb577d860322..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3WritableByteChannelTest.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.s3; - -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3Config; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3ConfigWithMultipleSSEOptions; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3ConfigWithSSEAlgorithm; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3ConfigWithSSEAwsKeyManagementParams; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3ConfigWithSSECustomerKey; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3Options; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3OptionsWithMultipleSSEOptions; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3OptionsWithSSEAlgorithm; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3OptionsWithSSEAwsKeyManagementParams; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.s3OptionsWithSSECustomerKey; -import static org.apache.beam.sdk.io.aws.s3.S3TestUtils.toMd5; -import static org.apache.beam.sdk.io.aws.s3.S3WritableByteChannel.atMostOne; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Matchers.notNull; -import static org.mockito.Mockito.RETURNS_SMART_NULLS; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static 
org.mockito.Mockito.withSettings; - -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest; -import com.amazonaws.services.s3.model.CompleteMultipartUploadResult; -import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest; -import com.amazonaws.services.s3.model.InitiateMultipartUploadResult; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; -import com.amazonaws.services.s3.model.UploadPartRequest; -import com.amazonaws.services.s3.model.UploadPartResult; -import java.io.IOException; -import java.nio.ByteBuffer; -import org.apache.beam.sdk.io.aws.options.S3Options; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link S3WritableByteChannel}. */ -@RunWith(JUnit4.class) -public class S3WritableByteChannelTest { - @Rule public ExpectedException expected = ExpectedException.none(); - - @Test - public void write() throws IOException { - writeFromConfig(s3Config("s3"), false); - writeFromConfig(s3Config("s3"), true); - writeFromConfig(s3ConfigWithSSEAlgorithm("s3"), false); - writeFromConfig(s3ConfigWithSSECustomerKey("s3"), false); - writeFromConfig(s3ConfigWithSSEAwsKeyManagementParams("s3"), false); - expected.expect(IllegalArgumentException.class); - writeFromConfig(s3ConfigWithMultipleSSEOptions("s3"), false); - } - - @Test - public void writeWithS3Options() throws IOException { - writeFromOptions(s3Options(), false); - writeFromOptions(s3Options(), true); - writeFromOptions(s3OptionsWithSSEAlgorithm(), false); - writeFromOptions(s3OptionsWithSSECustomerKey(), false); - writeFromOptions(s3OptionsWithSSEAwsKeyManagementParams(), false); - expected.expect(IllegalArgumentException.class); - writeFromOptions(s3OptionsWithMultipleSSEOptions(), false); - } - - @FunctionalInterface - public interface Supplier { - S3WritableByteChannel get() throws IOException; - 
} - - private void writeFromOptions(S3Options options, boolean writeReadOnlyBuffer) throws IOException { - AmazonS3 mockAmazonS3 = mock(AmazonS3.class, withSettings().defaultAnswer(RETURNS_SMART_NULLS)); - S3ResourceId path = S3ResourceId.fromUri("s3://bucket/dir/file"); - Supplier channel = - () -> - new S3WritableByteChannel( - mockAmazonS3, - path, - "text/plain", - S3FileSystemConfiguration.fromS3Options(options).build()); - write( - mockAmazonS3, - channel, - path, - options.getSSEAlgorithm(), - toMd5(options.getSSECustomerKey()), - options.getSSEAwsKeyManagementParams(), - options.getS3UploadBufferSizeBytes(), - options.getBucketKeyEnabled(), - writeReadOnlyBuffer); - } - - private void writeFromConfig(S3FileSystemConfiguration config, boolean writeReadOnlyBuffer) - throws IOException { - AmazonS3 mockAmazonS3 = mock(AmazonS3.class, withSettings().defaultAnswer(RETURNS_SMART_NULLS)); - S3ResourceId path = S3ResourceId.fromUri("s3://bucket/dir/file"); - Supplier channel = () -> new S3WritableByteChannel(mockAmazonS3, path, "text/plain", config); - write( - mockAmazonS3, - channel, - path, - config.getSSEAlgorithm(), - toMd5(config.getSSECustomerKey()), - config.getSSEAwsKeyManagementParams(), - config.getS3UploadBufferSizeBytes(), - config.getBucketKeyEnabled(), - writeReadOnlyBuffer); - } - - private void write( - AmazonS3 mockAmazonS3, - Supplier channelSupplier, - S3ResourceId path, - String sseAlgorithm, - String sseCustomerKeyMd5, - SSEAwsKeyManagementParams sseAwsKeyManagementParams, - long s3UploadBufferSizeBytes, - boolean bucketKeyEnabled, - boolean writeReadOnlyBuffer) - throws IOException { - InitiateMultipartUploadResult initiateMultipartUploadResult = - new InitiateMultipartUploadResult(); - initiateMultipartUploadResult.setUploadId("upload-id"); - if (sseAlgorithm != null) { - initiateMultipartUploadResult.setSSEAlgorithm(sseAlgorithm); - } - if (sseCustomerKeyMd5 != null) { - initiateMultipartUploadResult.setSSECustomerKeyMd5(sseCustomerKeyMd5); 
- } - if (sseAwsKeyManagementParams != null) { - sseAlgorithm = "aws:kms"; - initiateMultipartUploadResult.setSSEAlgorithm(sseAlgorithm); - } - initiateMultipartUploadResult.setBucketKeyEnabled(bucketKeyEnabled); - doReturn(initiateMultipartUploadResult) - .when(mockAmazonS3) - .initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)); - - InitiateMultipartUploadResult mockInitiateMultipartUploadResult = - mockAmazonS3.initiateMultipartUpload( - new InitiateMultipartUploadRequest(path.getBucket(), path.getKey())); - assertEquals(sseAlgorithm, mockInitiateMultipartUploadResult.getSSEAlgorithm()); - assertEquals(bucketKeyEnabled, mockInitiateMultipartUploadResult.getBucketKeyEnabled()); - assertEquals(sseCustomerKeyMd5, mockInitiateMultipartUploadResult.getSSECustomerKeyMd5()); - - UploadPartResult result = new UploadPartResult(); - result.setETag("etag"); - if (sseCustomerKeyMd5 != null) { - result.setSSECustomerKeyMd5(sseCustomerKeyMd5); - } - doReturn(result).when(mockAmazonS3).uploadPart(any(UploadPartRequest.class)); - - UploadPartResult mockUploadPartResult = mockAmazonS3.uploadPart(new UploadPartRequest()); - assertEquals(sseCustomerKeyMd5, mockUploadPartResult.getSSECustomerKeyMd5()); - - int contentSize = 34_078_720; - ByteBuffer uploadContent = ByteBuffer.allocate((int) (contentSize * 2.5)); - for (int i = 0; i < contentSize; i++) { - uploadContent.put((byte) 0xff); - } - uploadContent.flip(); - - S3WritableByteChannel channel = channelSupplier.get(); - int uploadedSize = - channel.write(writeReadOnlyBuffer ? 
uploadContent.asReadOnlyBuffer() : uploadContent); - assertEquals(contentSize, uploadedSize); - - CompleteMultipartUploadResult completeMultipartUploadResult = - new CompleteMultipartUploadResult(); - doReturn(completeMultipartUploadResult) - .when(mockAmazonS3) - .completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); - - channel.close(); - - verify(mockAmazonS3, times(2)) - .initiateMultipartUpload(notNull(InitiateMultipartUploadRequest.class)); - int partQuantity = (int) Math.ceil((double) contentSize / s3UploadBufferSizeBytes) + 1; - verify(mockAmazonS3, times(partQuantity)).uploadPart(notNull(UploadPartRequest.class)); - verify(mockAmazonS3, times(1)) - .completeMultipartUpload(notNull(CompleteMultipartUploadRequest.class)); - verifyNoMoreInteractions(mockAmazonS3); - } - - @Test - public void testAtMostOne() { - assertTrue(atMostOne(true)); - assertTrue(atMostOne(false)); - assertFalse(atMostOne(true, true)); - assertTrue(atMostOne(true, false)); - assertTrue(atMostOne(false, true)); - assertTrue(atMostOne(false, false)); - assertFalse(atMostOne(true, true, true)); - assertFalse(atMostOne(true, true, false)); - assertFalse(atMostOne(true, false, true)); - assertTrue(atMostOne(true, false, false)); - assertFalse(atMostOne(false, true, true)); - assertTrue(atMostOne(false, true, false)); - assertTrue(atMostOne(false, false, true)); - assertTrue(atMostOne(false, false, false)); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/PublishResultCodersTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/PublishResultCodersTest.java deleted file mode 100644 index e8f8643cbbd4..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/PublishResultCodersTest.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sns; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; - -import com.amazonaws.ResponseMetadata; -import com.amazonaws.http.HttpResponse; -import com.amazonaws.http.SdkHttpMetadata; -import com.amazonaws.services.sns.model.PublishResult; -import java.util.UUID; -import org.apache.beam.sdk.testing.CoderProperties; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.junit.Test; - -/** Tests for PublishResult coders. 
*/ -public class PublishResultCodersTest { - - @Test - public void testDefaultPublishResultDecodeEncodeEquals() throws Exception { - CoderProperties.coderDecodeEncodeEqual( - PublishResultCoders.defaultPublishResult(), - new PublishResult().withMessageId(UUID.randomUUID().toString())); - } - - @Test - public void testFullPublishResultWithoutHeadersDecodeEncodeEquals() throws Exception { - CoderProperties.coderDecodeEncodeEqual( - PublishResultCoders.fullPublishResultWithoutHeaders(), - new PublishResult().withMessageId(UUID.randomUUID().toString())); - - PublishResult value = buildFullPublishResult(); - PublishResult clone = - CoderUtils.clone(PublishResultCoders.fullPublishResultWithoutHeaders(), value); - assertThat( - clone.getSdkResponseMetadata().getRequestId(), - equalTo(value.getSdkResponseMetadata().getRequestId())); - assertThat( - clone.getSdkHttpMetadata().getHttpStatusCode(), - equalTo(value.getSdkHttpMetadata().getHttpStatusCode())); - assertThat(clone.getSdkHttpMetadata().getHttpHeaders().isEmpty(), equalTo(true)); - } - - @Test - public void testFullPublishResultIncludingHeadersDecodeEncodeEquals() throws Exception { - CoderProperties.coderDecodeEncodeEqual( - PublishResultCoders.fullPublishResult(), - new PublishResult().withMessageId(UUID.randomUUID().toString())); - - PublishResult value = buildFullPublishResult(); - PublishResult clone = CoderUtils.clone(PublishResultCoders.fullPublishResult(), value); - assertThat( - clone.getSdkResponseMetadata().getRequestId(), - equalTo(value.getSdkResponseMetadata().getRequestId())); - assertThat( - clone.getSdkHttpMetadata().getHttpStatusCode(), - equalTo(value.getSdkHttpMetadata().getHttpStatusCode())); - assertThat( - clone.getSdkHttpMetadata().getHttpHeaders(), - equalTo(value.getSdkHttpMetadata().getHttpHeaders())); - } - - private PublishResult buildFullPublishResult() { - PublishResult publishResult = new PublishResult().withMessageId(UUID.randomUUID().toString()); - 
publishResult.setSdkResponseMetadata( - new ResponseMetadata( - ImmutableMap.of(ResponseMetadata.AWS_REQUEST_ID, UUID.randomUUID().toString()))); - HttpResponse httpResponse = new HttpResponse(null, null); - httpResponse.setStatusCode(200); - httpResponse.addHeader("Content-Type", "application/json"); - publishResult.setSdkHttpMetadata(SdkHttpMetadata.from(httpResponse)); - return publishResult; - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java deleted file mode 100644 index c19aada628fa..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sns; - -import static org.apache.beam.sdk.io.common.IOITHelper.executeWithRetry; -import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; -import static org.apache.beam.sdk.values.TypeDescriptors.strings; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SNS; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SQS; - -import com.amazonaws.regions.Regions; -import com.amazonaws.services.sns.AmazonSNS; -import com.amazonaws.services.sns.AmazonSNSClientBuilder; -import com.amazonaws.services.sns.model.PublishRequest; -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.AmazonSQSClientBuilder; -import com.amazonaws.services.sqs.model.Message; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.Serializable; -import org.apache.beam.sdk.io.GenerateSequence; -import org.apache.beam.sdk.io.aws.ITEnvironment; -import org.apache.beam.sdk.io.aws.sqs.SqsIO; -import org.apache.beam.sdk.io.common.HashingFn; -import org.apache.beam.sdk.io.common.TestRow; -import org.apache.beam.sdk.io.common.TestRow.DeterministicallyConstructTestRowFn; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.junit.ClassRule; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.rules.Timeout; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import 
org.testcontainers.containers.localstack.LocalStackContainer.Service; - -@RunWith(JUnit4.class) -public class SnsIOIT { - public interface ITOptions extends ITEnvironment.ITOptions {} - - private static final ObjectMapper MAPPER = new ObjectMapper(); - private static final TypeDescriptor publishRequests = - TypeDescriptor.of(PublishRequest.class); - - @ClassRule - public static ITEnvironment env = - new ITEnvironment<>(new Service[] {SQS, SNS}, ITOptions.class, "SQS_PROVIDER=elasticmq"); - - @Rule public Timeout globalTimeout = Timeout.seconds(600); - - @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); - @Rule public TestPipeline pipelineRead = env.createTestPipeline(); - @Rule public AwsResources resources = new AwsResources(); - - @Test - public void testWriteThenRead() { - ITOptions opts = env.options(); - int rows = opts.getNumberOfRows(); - - // Write test dataset to SNS - - pipelineWrite - .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) - .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) - .apply("SNS request", MapElements.into(publishRequests).via(resources::publishRequest)) - .apply( - "Write to SNS", - SnsIO.write() - .withTopicName(resources.snsTopic) - .withResultOutputTag(new TupleTag<>()) - .withAWSClientsProvider( - opts.getAwsCredentialsProvider().getCredentials().getAWSAccessKeyId(), - opts.getAwsCredentialsProvider().getCredentials().getAWSSecretKey(), - Regions.fromName(opts.getAwsRegion()), - opts.getAwsServiceEndpoint())); - - // Read test dataset from SQS. 
- PCollection output = - pipelineRead - .apply( - "Read from SQS", - SqsIO.read().withQueueUrl(resources.sqsQueue).withMaxNumRecords(rows)) - .apply("Extract message", MapElements.into(strings()).via(SnsIOIT::extractMessage)); - - PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows); - - PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults())) - .containsInAnyOrder(getExpectedHashForRowCount(rows)); - - pipelineWrite.run(); - pipelineRead.run(); - } - - private static String extractMessage(Message msg) { - try { - return MAPPER.readTree(msg.getBody()).get("Message").asText(); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - private static class AwsResources extends ExternalResource implements Serializable { - private transient AmazonSQS sqs = env.buildClient(AmazonSQSClientBuilder.standard()); - private transient AmazonSNS sns = env.buildClient(AmazonSNSClientBuilder.standard()); - - private String sqsQueue; - private String snsTopic; - private String sns2Sqs; - - PublishRequest publishRequest(TestRow r) { - return new PublishRequest(snsTopic, r.name()); - } - - @Override - protected void before() throws Throwable { - snsTopic = sns.createTopic("beam-snsio-it").getTopicArn(); - // add SQS subscription so we can read the messages again - sqsQueue = sqs.createQueue("beam-snsio-it").getQueueUrl(); - sns2Sqs = sns.subscribe(snsTopic, "sqs", sqsQueue).getSubscriptionArn(); - } - - @Override - protected void after() { - try { - executeWithRetry(() -> sns.unsubscribe(sns2Sqs)); - executeWithRetry(() -> sns.deleteTopic(snsTopic)); - executeWithRetry(() -> sqs.deleteQueue(sqsQueue)); - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - sns.shutdown(); - sqs.shutdown(); - } - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java 
b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java deleted file mode 100644 index f86c0851a01c..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sns; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.joda.time.Duration.millis; -import static org.joda.time.Duration.standardSeconds; - -import com.amazonaws.http.SdkHttpMetadata; -import com.amazonaws.services.sns.AmazonSNS; -import com.amazonaws.services.sns.model.GetTopicAttributesResult; -import com.amazonaws.services.sns.model.InternalErrorException; -import com.amazonaws.services.sns.model.PublishRequest; -import com.amazonaws.services.sns.model.PublishResult; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.Serializable; -import java.util.HashMap; -import java.util.UUID; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.testing.ExpectedLogs; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TypeDescriptors; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mockito; -import org.slf4j.helpers.MessageFormatter; - -/** Tests to verify writes to Sns. 
*/ -@RunWith(JUnit4.class) -public class SnsIOTest implements Serializable { - - private static final String topicName = "arn:aws:sns:us-west-2:5880:topic-FMFEHJ47NRFO"; - - @Rule public TestPipeline p = TestPipeline.create(); - - @Rule - public final transient ExpectedLogs snsWriterFnLogs = - ExpectedLogs.none(SnsIO.Write.SnsWriterFn.class); - - private static PublishRequest createSampleMessage(String message) { - return new PublishRequest().withTopicArn(topicName).withMessage(message); - } - - private static class Provider implements AwsClientsProvider { - - private static AmazonSNS publisher; - - public Provider(AmazonSNS pub) { - publisher = pub; - } - - @Override - public AmazonSNS createSnsPublisher() { - return publisher; - } - } - - @Test - public void testDataWritesToSNS() { - final PublishRequest request1 = createSampleMessage("my_first_message"); - final PublishRequest request2 = createSampleMessage("my_second_message"); - - final TupleTag results = new TupleTag<>(); - final AmazonSNS amazonSnsSuccess = getAmazonSnsMockSuccess(); - - final PCollectionTuple snsWrites = - p.apply(Create.of(request1, request2)) - .apply( - SnsIO.write() - .withTopicName(topicName) - .withRetryConfiguration( - SnsIO.RetryConfiguration.create( - 5, org.joda.time.Duration.standardMinutes(1))) - .withAWSClientsProvider(new Provider(amazonSnsSuccess)) - .withResultOutputTag(results)); - - final PCollection publishedResultsSize = snsWrites.get(results).apply(Count.globally()); - PAssert.that(publishedResultsSize).containsInAnyOrder(ImmutableList.of(2L)); - p.run().waitUntilFinish(); - } - - @Rule public ExpectedException thrown = ExpectedException.none(); - - @Test - public void testRetries() throws Throwable { - thrown.expect(IOException.class); - thrown.expectMessage("Error writing to SNS"); - thrown.expectMessage("No more attempts allowed"); - - final PublishRequest request1 = createSampleMessage("my message that will not be published"); - final TupleTag results = new 
TupleTag<>(); - final AmazonSNS amazonSnsErrors = getAmazonSnsMockErrors(); - p.apply(Create.of(request1)) - .apply( - SnsIO.write() - .withTopicName(topicName) - .withRetryConfiguration( - SnsIO.RetryConfiguration.create(4, standardSeconds(10), millis(1))) - .withAWSClientsProvider(new Provider(amazonSnsErrors)) - .withResultOutputTag(results)); - - try { - p.run(); - } catch (final Pipeline.PipelineExecutionException e) { - // check 3 retries were initiated by inspecting the log before passing on the exception - snsWriterFnLogs.verifyWarn( - MessageFormatter.format(SnsIO.Write.SnsWriterFn.RETRY_ATTEMPT_LOG, 1).getMessage()); - snsWriterFnLogs.verifyWarn( - MessageFormatter.format(SnsIO.Write.SnsWriterFn.RETRY_ATTEMPT_LOG, 2).getMessage()); - snsWriterFnLogs.verifyWarn( - MessageFormatter.format(SnsIO.Write.SnsWriterFn.RETRY_ATTEMPT_LOG, 3).getMessage()); - throw e.getCause(); - } - } - - @Test - public void testCustomCoder() throws Exception { - final PublishRequest request1 = createSampleMessage("my_first_message"); - - final TupleTag results = new TupleTag<>(); - final AmazonSNS amazonSnsSuccess = getAmazonSnsMockSuccess(); - final MockCoder mockCoder = new MockCoder(); - - final PCollectionTuple snsWrites = - p.apply(Create.of(request1)) - .apply( - SnsIO.write() - .withTopicName(topicName) - .withAWSClientsProvider(new Provider(amazonSnsSuccess)) - .withResultOutputTag(results) - .withCoder(mockCoder)); - - final PCollection publishedResultsSize = - snsWrites - .get(results) - .apply(MapElements.into(TypeDescriptors.strings()).via(result -> result.getMessageId())) - .apply(Count.globally()); - PAssert.that(publishedResultsSize).containsInAnyOrder(ImmutableList.of(1L)); - p.run().waitUntilFinish(); - assertThat(mockCoder.captured).isNotNull(); - } - - // Hand-code mock because Mockito mocks cause NotSerializableException even with - // withSettings().serializable(). 
- private static class MockCoder extends AtomicCoder { - - private PublishResult captured; - - @Override - public void encode(PublishResult value, OutputStream outStream) - throws CoderException, IOException { - this.captured = value; - PublishResultCoders.defaultPublishResult().encode(value, outStream); - } - - @Override - public PublishResult decode(InputStream inStream) throws CoderException, IOException { - return PublishResultCoders.defaultPublishResult().decode(inStream); - } - }; - - private static AmazonSNS getAmazonSnsMockSuccess() { - final AmazonSNS amazonSNS = Mockito.mock(AmazonSNS.class); - configureAmazonSnsMock(amazonSNS); - - final PublishResult result = Mockito.mock(PublishResult.class); - final SdkHttpMetadata metadata = Mockito.mock(SdkHttpMetadata.class); - Mockito.when(metadata.getHttpHeaders()).thenReturn(new HashMap<>()); - Mockito.when(metadata.getHttpStatusCode()).thenReturn(200); - Mockito.when(result.getSdkHttpMetadata()).thenReturn(metadata); - Mockito.when(result.getMessageId()).thenReturn(UUID.randomUUID().toString()); - Mockito.when(amazonSNS.publish(Mockito.any())).thenReturn(result); - return amazonSNS; - } - - private static AmazonSNS getAmazonSnsMockErrors() { - final AmazonSNS amazonSNS = Mockito.mock(AmazonSNS.class); - configureAmazonSnsMock(amazonSNS); - - Mockito.when(amazonSNS.publish(Mockito.any())) - .thenThrow(new InternalErrorException("Service unavailable")); - return amazonSNS; - } - - private static void configureAmazonSnsMock(AmazonSNS amazonSNS) { - final GetTopicAttributesResult result = Mockito.mock(GetTopicAttributesResult.class); - final SdkHttpMetadata metadata = Mockito.mock(SdkHttpMetadata.class); - Mockito.when(metadata.getHttpHeaders()).thenReturn(new HashMap<>()); - Mockito.when(metadata.getHttpStatusCode()).thenReturn(200); - Mockito.when(result.getSdkHttpMetadata()).thenReturn(metadata); - Mockito.when(amazonSNS.getTopicAttributes(Mockito.anyString())).thenReturn(result); - } -} diff --git 
a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/EmbeddedSqsServer.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/EmbeddedSqsServer.java deleted file mode 100644 index 543df65bbe11..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/EmbeddedSqsServer.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.aws.sqs; - -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.AmazonSQSClientBuilder; -import com.amazonaws.services.sqs.model.CreateQueueResult; -import org.elasticmq.rest.sqs.SQSRestServer; -import org.elasticmq.rest.sqs.SQSRestServerBuilder; -import org.junit.rules.ExternalResource; - -class EmbeddedSqsServer extends ExternalResource { - - private SQSRestServer sqsRestServer; - private AmazonSQS client; - private String queueUrl; - - @Override - protected void before() { - sqsRestServer = SQSRestServerBuilder.withDynamicPort().start(); - int port = sqsRestServer.waitUntilStarted().localAddress().getPort(); - - String endpoint = String.format("http://localhost:%d", port); - String region = "elasticmq"; - String accessKey = "x"; - String secretKey = "x"; - - client = - AmazonSQSClientBuilder.standard() - .withCredentials( - new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey))) - .withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(endpoint, region)) - .build(); - final CreateQueueResult queue = client.createQueue("test"); - queueUrl = queue.getQueueUrl(); - } - - @Override - protected void after() { - sqsRestServer.stopAndWait(); - client.shutdown(); - } - - public AmazonSQS getClient() { - return client; - } - - public String getQueueUrl() { - return queueUrl; - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsIOIT.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsIOIT.java deleted file mode 100644 index a44cb29a1abc..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsIOIT.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; -import static org.apache.beam.sdk.values.TypeDescriptors.strings; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SQS; - -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.AmazonSQSClientBuilder; -import com.amazonaws.services.sqs.model.Message; -import com.amazonaws.services.sqs.model.SendMessageRequest; -import java.io.Serializable; -import org.apache.beam.sdk.io.GenerateSequence; -import org.apache.beam.sdk.io.aws.ITEnvironment; -import org.apache.beam.sdk.io.common.HashingFn; -import org.apache.beam.sdk.io.common.TestRow; -import org.apache.beam.sdk.io.common.TestRow.DeterministicallyConstructTestRowFn; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.junit.ClassRule; -import 
org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.rules.Timeout; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class SqsIOIT { - public interface SqsITOptions extends ITEnvironment.ITOptions {} - - private static final TypeDescriptor requestType = - TypeDescriptor.of(SendMessageRequest.class); - - @ClassRule - public static ITEnvironment env = - new ITEnvironment<>(SQS, SqsITOptions.class, "SQS_PROVIDER=elasticmq"); - - @Rule public Timeout globalTimeout = Timeout.seconds(600); - - @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); - @Rule public TestPipeline pipelineRead = env.createTestPipeline(); - @Rule public SqsQueue sqsQueue = new SqsQueue(); - - @Test - public void testWriteThenRead() { - int rows = env.options().getNumberOfRows(); - - // Write test dataset to SQS. - pipelineWrite - .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) - .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) - .apply("Prepare SQS message", MapElements.into(requestType).via(sqsQueue::messageRequest)) - .apply("Write to SQS", SqsIO.write()); - - // Read test dataset from SQS. 
- PCollection output = - pipelineRead - .apply("Read from SQS", SqsIO.read().withQueueUrl(sqsQueue.url).withMaxNumRecords(rows)) - .apply("Extract body", MapElements.into(strings()).via(Message::getBody)); - - PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows); - - PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults())) - .containsInAnyOrder(getExpectedHashForRowCount(rows)); - - pipelineWrite.run(); - pipelineRead.run(); - } - - private static class SqsQueue extends ExternalResource implements Serializable { - private transient AmazonSQS client = env.buildClient(AmazonSQSClientBuilder.standard()); - private String url; - - SendMessageRequest messageRequest(TestRow r) { - return new SendMessageRequest(url, r.name()); - } - - @Override - protected void before() { - url = client.createQueue("beam-sqsio-it").getQueueUrl(); - } - - @Override - protected void after() { - client.deleteQueue(url); - client.shutdown(); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsIOTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsIOTest.java deleted file mode 100644 index 23cc56a9438d..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsIOTest.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.model.Message; -import com.amazonaws.services.sqs.model.ReceiveMessageResult; -import com.amazonaws.services.sqs.model.SendMessageRequest; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests on {@link SqsIO}. 
*/ -@RunWith(JUnit4.class) -public class SqsIOTest { - - @Rule public TestPipeline pipeline = TestPipeline.create(); - - @Rule public EmbeddedSqsServer embeddedSqsRestServer = new EmbeddedSqsServer(); - - @Test - public void testWrite() { - final AmazonSQS client = embeddedSqsRestServer.getClient(); - final String queueUrl = embeddedSqsRestServer.getQueueUrl(); - - List messages = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - final SendMessageRequest request = new SendMessageRequest(queueUrl, "This is a test " + i); - messages.add(request); - } - pipeline.apply(Create.of(messages)).apply(SqsIO.write()); - pipeline.run().waitUntilFinish(); - - List received = new ArrayList<>(); - while (received.size() < 100) { - final ReceiveMessageResult receiveMessageResult = client.receiveMessage(queueUrl); - - if (receiveMessageResult.getMessages() != null) { - for (Message message : receiveMessageResult.getMessages()) { - received.add(message.getBody()); - } - } - } - assertEquals(100, received.size()); - for (int i = 0; i < 100; i++) { - assertTrue(received.contains("This is a test " + i)); - } - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsMessageCoderTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsMessageCoderTest.java deleted file mode 100644 index 933028306d8b..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsMessageCoderTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import static com.amazonaws.services.sqs.model.MessageSystemAttributeName.SentTimestamp; -import static org.apache.beam.sdk.io.aws.sqs.SqsUnboundedReader.REQUEST_TIME; -import static org.assertj.core.api.Assertions.assertThat; - -import com.amazonaws.services.sqs.model.Message; -import com.amazonaws.services.sqs.model.MessageAttributeValue; -import java.util.Random; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.junit.Test; - -public class SqsMessageCoderTest { - - @Test - public void testMessageDecodeEncodeEquals() throws Exception { - Message message = - new Message() - .withMessageId("messageId") - .withReceiptHandle("receiptHandle") - .withBody("body") - .withAttributes( - ImmutableMap.of(SentTimestamp.name(), Long.toString(new Random().nextLong()))) - .withMessageAttributes( - ImmutableMap.of( - REQUEST_TIME, - new MessageAttributeValue() - .withStringValue(Long.toString(new Random().nextLong())))); - - Message clone = CoderUtils.clone(SqsMessageCoder.of(), message); - assertThat(clone).isEqualTo(message); - } - - @Test - public void testVerifyDeterministic() throws Exception { - SqsMessageCoder.of().verifyDeterministic(); // must not throw - } - - @Test - public void testConsistentWithEquals() { - // some attributes might be omitted - assertThat(SqsMessageCoder.of().consistentWithEquals()).isFalse(); - } -} diff --git 
a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedReaderTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedReaderTest.java deleted file mode 100644 index a6e986251cb6..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedReaderTest.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import static junit.framework.TestCase.assertFalse; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import com.amazonaws.services.sqs.AmazonSQS; -import com.amazonaws.services.sqs.model.Message; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.sdk.io.aws.options.AwsOptions; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.util.CoderUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests on {@link SqsUnboundedReader}. 
*/ -@RunWith(JUnit4.class) -public class SqsUnboundedReaderTest { - private static final String DATA = "testData"; - - @Rule public TestPipeline pipeline = TestPipeline.create(); - - @Rule public EmbeddedSqsServer embeddedSqsRestServer = new EmbeddedSqsServer(); - - private SqsUnboundedSource source; - - private void setupOneMessage() { - final AmazonSQS client = embeddedSqsRestServer.getClient(); - final String queueUrl = embeddedSqsRestServer.getQueueUrl(); - client.sendMessage(queueUrl, DATA); - source = - new SqsUnboundedSource( - SqsIO.read().withQueueUrl(queueUrl).withMaxNumRecords(1), - new SqsConfiguration(pipeline.getOptions().as(AwsOptions.class)), - SqsMessageCoder.of()); - } - - private void setupMessages(List messages) { - final AmazonSQS client = embeddedSqsRestServer.getClient(); - final String queueUrl = embeddedSqsRestServer.getQueueUrl(); - for (String message : messages) { - client.sendMessage(queueUrl, message); - } - source = - new SqsUnboundedSource( - SqsIO.read().withQueueUrl(queueUrl).withMaxNumRecords(1), - new SqsConfiguration(pipeline.getOptions().as(AwsOptions.class)), - SqsMessageCoder.of()); - } - - @Test - public void testReadOneMessage() throws IOException { - setupOneMessage(); - UnboundedSource.UnboundedReader reader = - source.createReader(pipeline.getOptions(), null); - // Read one message. - assertTrue(reader.start()); - assertEquals(DATA, reader.getCurrent().getBody()); - assertFalse(reader.advance()); - // ACK the message. 
- UnboundedSource.CheckpointMark checkpoint = reader.getCheckpointMark(); - checkpoint.finalizeCheckpoint(); - reader.close(); - } - - @Test - public void testTimeoutAckAndRereadOneMessage() throws IOException { - setupOneMessage(); - UnboundedSource.UnboundedReader reader = - source.createReader(pipeline.getOptions(), null); - AmazonSQS sqsClient = embeddedSqsRestServer.getClient(); - assertTrue(reader.start()); - assertEquals(DATA, reader.getCurrent().getBody()); - String receiptHandle = reader.getCurrent().getReceiptHandle(); - // Set the message to timeout. - sqsClient.changeMessageVisibility(source.getRead().queueUrl(), receiptHandle, 0); - // We'll now receive the same message again. - assertTrue(reader.advance()); - assertEquals(DATA, reader.getCurrent().getBody()); - assertFalse(reader.advance()); - // Now ACK the message. - UnboundedSource.CheckpointMark checkpoint = reader.getCheckpointMark(); - checkpoint.finalizeCheckpoint(); - reader.close(); - } - - @Test - public void testMultipleReaders() throws IOException { - List incoming = new ArrayList<>(); - for (int i = 0; i < 2; i++) { - incoming.add(String.format("data_%d", i)); - } - setupMessages(incoming); - UnboundedSource.UnboundedReader reader = - source.createReader(pipeline.getOptions(), null); - // Consume two messages, only read one. - assertTrue(reader.start()); - assertEquals("data_0", reader.getCurrent().getBody()); - - // Grab checkpoint. - SqsCheckpointMark checkpoint = (SqsCheckpointMark) reader.getCheckpointMark(); - checkpoint.finalizeCheckpoint(); - assertEquals(1, checkpoint.notYetReadReceipts.size()); - - // Read second message. - assertTrue(reader.advance()); - assertEquals("data_1", reader.getCurrent().getBody()); - - // Restore from checkpoint. 
- byte[] checkpointBytes = - CoderUtils.encodeToByteArray(source.getCheckpointMarkCoder(), checkpoint); - checkpoint = CoderUtils.decodeFromByteArray(source.getCheckpointMarkCoder(), checkpointBytes); - assertEquals(1, checkpoint.notYetReadReceipts.size()); - - // Re-read second message. - reader = source.createReader(pipeline.getOptions(), checkpoint); - assertTrue(reader.start()); - assertEquals("data_1", reader.getCurrent().getBody()); - - // We are done. - assertFalse(reader.advance()); - - // ACK final message. - checkpoint = (SqsCheckpointMark) reader.getCheckpointMark(); - checkpoint.finalizeCheckpoint(); - reader.close(); - } - - @Test - public void testReadMany() throws IOException { - - HashSet messages = new HashSet<>(); - List incoming = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - String content = String.format("data_%d", i); - messages.add(content); - incoming.add(String.format("data_%d", i)); - } - setupMessages(incoming); - - SqsUnboundedReader reader = - (SqsUnboundedReader) source.createReader(pipeline.getOptions(), null); - - for (int i = 0; i < 100; i++) { - if (i == 0) { - assertTrue(reader.start()); - } else { - assertTrue(reader.advance()); - } - String data = reader.getCurrent().getBody(); - boolean messageNum = messages.remove(data); - // No duplicate messages. - assertTrue(messageNum); - } - // We are done. - assertFalse(reader.advance()); - // We saw each message exactly once. - assertTrue(messages.isEmpty()); - reader.close(); - } - - /** Tests that checkpoints finalized after the reader is closed succeed. 
*/ - @Test - public void testCloseWithActiveCheckpoints() throws Exception { - setupOneMessage(); - UnboundedSource.UnboundedReader reader = - source.createReader(pipeline.getOptions(), null); - reader.start(); - UnboundedSource.CheckpointMark checkpoint = reader.getCheckpointMark(); - reader.close(); - checkpoint.finalizeCheckpoint(); - } -} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedSourceTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedSourceTest.java deleted file mode 100644 index 58099dc17ee5..000000000000 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sqs/SqsUnboundedSourceTest.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws.sqs; - -import com.amazonaws.services.sqs.AmazonSQS; -import org.apache.beam.sdk.io.aws.options.AwsOptions; -import org.apache.beam.sdk.testing.CoderProperties; -import org.apache.beam.sdk.testing.TestPipeline; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests on {@link SqsUnboundedSource}. 
*/ -@RunWith(JUnit4.class) -public class SqsUnboundedSourceTest { - - private static final String DATA = "testData"; - - @Rule public TestPipeline pipeline = TestPipeline.create(); - - @Rule public EmbeddedSqsServer embeddedSqsRestServer = new EmbeddedSqsServer(); - - @Test - public void testCheckpointCoderIsSane() { - final AmazonSQS client = embeddedSqsRestServer.getClient(); - final String queueUrl = embeddedSqsRestServer.getQueueUrl(); - client.sendMessage(queueUrl, DATA); - SqsUnboundedSource source = - new SqsUnboundedSource( - SqsIO.read().withQueueUrl(queueUrl).withMaxNumRecords(1), - new SqsConfiguration(pipeline.getOptions().as(AwsOptions.class)), - SqsMessageCoder.of()); - CoderProperties.coderSerializable(source.getCheckpointMarkCoder()); - } -} diff --git a/sdks/java/io/amazon-web-services2/build.gradle b/sdks/java/io/amazon-web-services2/build.gradle index b1b5e3abee69..122fa18de895 100644 --- a/sdks/java/io/amazon-web-services2/build.gradle +++ b/sdks/java/io/amazon-web-services2/build.gradle @@ -52,7 +52,7 @@ dependencies { implementation library.java.aws_java_sdk2_http_client_spi, excludeNetty implementation library.java.aws_java_sdk2_apache_client, excludeNetty implementation library.java.aws_java_sdk2_netty_client, excludeNetty - implementation("software.amazon.kinesis:amazon-kinesis-client:2.4.8") { + implementation("software.amazon.kinesis:amazon-kinesis-client:3.0.1") { // Note: The KCL client isn't used. However, unfortunately, some model classes of KCL leak into the // KinesisIO API (KinesisClientRecord, InitialPositionInStream). Additionally, KinesisIO // internally uses KCL utils to generate aggregated messages and de-aggregate them. 
diff --git a/sdks/java/io/kinesis/expansion-service/build.gradle b/sdks/java/io/amazon-web-services2/expansion-service/build.gradle similarity index 72% rename from sdks/java/io/kinesis/expansion-service/build.gradle rename to sdks/java/io/amazon-web-services2/expansion-service/build.gradle index 3bb7317924d7..fd712737f53c 100644 --- a/sdks/java/io/kinesis/expansion-service/build.gradle +++ b/sdks/java/io/amazon-web-services2/expansion-service/build.gradle @@ -21,19 +21,19 @@ apply plugin: 'application' mainClassName = "org.apache.beam.sdk.expansion.service.ExpansionService" applyJavaNature( - automaticModuleName: 'org.apache.beam.sdk.io.kinesis.expansion.service', + automaticModuleName: 'org.apache.beam.sdk.io.amazon-web-services2.expansion.service', exportJavadoc: false, validateShadowJar: false, shadowClosure: {}, ) -description = "Apache Beam :: SDKs :: Java :: IO :: Kinesis :: Expansion Service" -ext.summary = "Expansion service serving KinesisIO" +description = "Apache Beam :: SDKs :: Java :: IO :: Amazon Web Services 2 :: Expansion Service" +ext.summary = "Expansion service serving AWS2" dependencies { implementation project(":sdks:java:expansion-service") - permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761 - implementation project(":sdks:java:io:kinesis") - permitUnusedDeclared project(":sdks:java:io:kinesis") // BEAM-11761 + permitUnusedDeclared project(":sdks:java:expansion-service") + implementation project(":sdks:java:io:amazon-web-services2") + permitUnusedDeclared project(":sdks:java:io:amazon-web-services2") runtimeOnly library.java.slf4j_jdk14 -} +} \ No newline at end of file diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientBuilderFactory.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientBuilderFactory.java index 6398de57b5c3..8d8531ce5cdf 100644 --- 
a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientBuilderFactory.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientBuilderFactory.java @@ -24,10 +24,14 @@ import java.io.Serializable; import java.net.URI; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.time.Duration; import java.util.function.Consumer; import java.util.function.Function; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; import org.apache.beam.sdk.io.aws2.options.AwsOptions; import org.apache.beam.sdk.util.InstanceBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -37,9 +41,12 @@ import software.amazon.awssdk.core.client.builder.SdkSyncClientBuilder; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.http.Protocol; +import software.amazon.awssdk.http.TlsTrustManagersProvider; import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.http.apache.ProxyConfiguration; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.internal.http.NoneTlsKeyManagersProvider; import software.amazon.awssdk.regions.Region; /** @@ -113,6 +120,32 @@ static , ClientT> ClientT b return ClientBuilderFactory.getFactory(options).create(builder, config, options).build(); } + /** Trust provider to skip certificate verification. Should only be used for test pipelines. 
*/ + class SkipCertificateVerificationTrustManagerProvider implements TlsTrustManagersProvider { + public SkipCertificateVerificationTrustManagerProvider() {} + + @Override + public TrustManager[] trustManagers() { + TrustManager tm = + new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] x509CertificateArr, String str) + throws CertificateException {} + + @Override + public void checkServerTrusted(X509Certificate[] x509CertificateArr, String str) + throws CertificateException {} + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }; + TrustManager[] tms = {tm}; + return tms; + } + } + /** * Default implementation of {@link ClientBuilderFactory}. This implementation can configure both, * synchronous clients using {@link ApacheHttpClient} as well as asynchronous clients using {@link @@ -161,7 +194,11 @@ public , ClientT> BuilderT HttpClientConfiguration httpConfig = options.getHttpClientConfiguration(); ProxyConfiguration proxyConfig = options.getProxyConfiguration(); - if (proxyConfig != null || httpConfig != null) { + boolean skipCertificateVerification = false; + if (config.skipCertificateVerification() != null) { + skipCertificateVerification = config.skipCertificateVerification(); + } + if (proxyConfig != null || httpConfig != null || skipCertificateVerification) { if (builder instanceof SdkSyncClientBuilder) { ApacheHttpClient.Builder client = syncClientBuilder(); @@ -177,6 +214,11 @@ public , ClientT> BuilderT setOptional(httpConfig.maxConnections(), client::maxConnections); } + if (skipCertificateVerification) { + client.tlsKeyManagersProvider(NoneTlsKeyManagersProvider.getInstance()); + client.tlsTrustManagersProvider(new SkipCertificateVerificationTrustManagerProvider()); + } + // must use builder to make sure client is managed by the SDK ((SdkSyncClientBuilder) builder).httpClientBuilder(client); } else if (builder instanceof SdkAsyncClientBuilder) { @@ -201,6 +243,12 @@ 
public , ClientT> BuilderT setOptional(httpConfig.maxConnections(), client::maxConcurrency); } + if (skipCertificateVerification) { + client.tlsKeyManagersProvider(NoneTlsKeyManagersProvider.getInstance()); + client.tlsTrustManagersProvider(new SkipCertificateVerificationTrustManagerProvider()); + client.protocol(Protocol.HTTP1_1); + } + // must use builder to make sure client is managed by the SDK ((SdkAsyncClientBuilder) builder).httpClientBuilder(client); } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java index 08fb595bd037..385a25b5a13f 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java @@ -76,6 +76,13 @@ public abstract class ClientConfiguration implements Serializable { return regionId() != null ? Region.of(regionId()) : null; } + /** + * Optional flag to skip certificate verification. Should only be overriden for test scenarios. If + * set, this overwrites the default in {@link AwsOptions#skipCertificateVerification()}. + */ + @JsonProperty + public abstract @Nullable @Pure Boolean skipCertificateVerification(); + /** * Optional service endpoint to use AWS compatible services instead, e.g. for testing. If set, * this overwrites the default in {@link AwsOptions#getEndpoint()}. @@ -156,6 +163,13 @@ public Builder retry(Consumer retry) { return retry(builder.build()); } + /** + * Optional flag to skip certificate verification. Should only be overriden for test scenarios. + * If set, this overwrites the default in {@link AwsOptions#skipCertificateVerification()}. 
+ */ + @JsonProperty + public abstract Builder skipCertificateVerification(boolean skipCertificateVerification); + abstract Builder regionId(String region); abstract Builder credentialsProviderAsJson(String credentialsProvider); diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisTransformRegistrar.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisTransformRegistrar.java similarity index 64% rename from sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisTransformRegistrar.java rename to sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisTransformRegistrar.java index b8e1a38c73ff..51d4202e4027 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisTransformRegistrar.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisTransformRegistrar.java @@ -15,35 +15,43 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.io.kinesis; +package org.apache.beam.sdk.io.aws2.kinesis; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.google.auto.service.AutoService; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; -import java.util.Properties; import org.apache.beam.sdk.expansion.ExternalTransformRegistrar; +import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; +import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ExternalTransformBuilder; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PDone; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.kinesis.common.InitialPositionInStream; /** - * Exposes {@link KinesisIO.Write} and {@link KinesisIO.Read} as an external transform for - * cross-language usage. + * Exposes {@link org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write} and {@link + * org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Read} as an external transform for cross-language + * usage. 
*/ @AutoService(ExternalTransformRegistrar.class) @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class KinesisTransformRegistrar implements ExternalTransformRegistrar { - public static final String WRITE_URN = "beam:transform:org.apache.beam:kinesis_write:v1"; - public static final String READ_DATA_URN = "beam:transform:org.apache.beam:kinesis_read_data:v1"; + public static final String WRITE_URN = "beam:transform:org.apache.beam:kinesis_write:v2"; + public static final String READ_DATA_URN = "beam:transform:org.apache.beam:kinesis_read_data:v2"; @Override public Map> knownBuilderInstances() { @@ -54,7 +62,7 @@ private abstract static class CrossLanguageConfiguration { String streamName; String awsAccessKey; String awsSecretKey; - Regions region; + String region; @Nullable String serviceEndpoint; boolean verifyCertificate; @@ -71,7 +79,7 @@ public void setAwsSecretKey(String awsSecretKey) { } public void setRegion(String region) { - this.region = Regions.fromName(region); + this.region = region; } public void setServiceEndpoint(@Nullable String serviceEndpoint) { @@ -84,41 +92,48 @@ public void setVerifyCertificate(@Nullable Boolean verifyCertificate) { } public static class WriteBuilder - implements ExternalTransformBuilder, PDone> { + implements ExternalTransformBuilder< + WriteBuilder.Configuration, PCollection, KinesisIO.Write.Result> { public static class Configuration extends CrossLanguageConfiguration { - private Properties producerProperties; private String partitionKey; - public void setProducerProperties(Map producerProperties) { - if (producerProperties != null) { - Properties properties = new Properties(); - producerProperties.forEach(properties::setProperty); - this.producerProperties = properties; - } - } - public void setPartitionKey(String partitionKey) { this.partitionKey = partitionKey; } } @Override - public PTransform, PDone> buildExternal(Configuration configuration) { - KinesisIO.Write 
writeTransform = - KinesisIO.write() - .withStreamName(configuration.streamName) - .withAWSClientsProvider( - configuration.awsAccessKey, - configuration.awsSecretKey, - configuration.region, - configuration.serviceEndpoint, - configuration.verifyCertificate) - .withPartitionKey(configuration.partitionKey); - - if (configuration.producerProperties != null) { - writeTransform = writeTransform.withProducerProperties(configuration.producerProperties); + public PTransform, KinesisIO.Write.Result> buildExternal( + Configuration configuration) { + AwsBasicCredentials creds = + AwsBasicCredentials.create(configuration.awsAccessKey, configuration.awsSecretKey); + String pk = configuration.partitionKey; + StaticCredentialsProvider provider = StaticCredentialsProvider.create(creds); + SerializableFunction serializer = v -> v; + @Nullable URI endpoint = null; + if (configuration.serviceEndpoint != null) { + try { + endpoint = new URI(configuration.serviceEndpoint); + } catch (URISyntaxException ex) { + throw new RuntimeException( + String.format( + "Service endpoint must be URI format, got: %s", configuration.serviceEndpoint)); + } } + KinesisIO.Write writeTransform = + KinesisIO.write() + .withStreamName(configuration.streamName) + .withClientConfiguration( + ClientConfiguration.builder() + .credentialsProvider(provider) + .region(Region.of(configuration.region)) + .endpoint(endpoint) + .skipCertificateVerification(!configuration.verifyCertificate) + .build()) + .withPartitioner(p -> pk) + .withRecordAggregationDisabled() + .withSerializer(serializer); return writeTransform; } @@ -205,15 +220,29 @@ private enum WatermarkPolicy { @Override public PTransform> buildExternal( ReadDataBuilder.Configuration configuration) { - KinesisIO.Read readTransform = - KinesisIO.readData() + AwsBasicCredentials creds = + AwsBasicCredentials.create(configuration.awsAccessKey, configuration.awsSecretKey); + StaticCredentialsProvider provider = StaticCredentialsProvider.create(creds); + 
@Nullable URI endpoint = null; + if (configuration.serviceEndpoint != null) { + try { + endpoint = new URI(configuration.serviceEndpoint); + } catch (URISyntaxException ex) { + throw new RuntimeException( + String.format( + "Service endpoint must be URI format, got: %s", configuration.serviceEndpoint)); + } + } + KinesisIO.Read readTransform = + KinesisIO.read() .withStreamName(configuration.streamName) - .withAWSClientsProvider( - configuration.awsAccessKey, - configuration.awsSecretKey, - configuration.region, - configuration.serviceEndpoint, - configuration.verifyCertificate); + .withClientConfiguration( + ClientConfiguration.builder() + .credentialsProvider(provider) + .region(Region.of(configuration.region)) + .endpoint(endpoint) + .skipCertificateVerification(!configuration.verifyCertificate) + .build()); if (configuration.maxNumRecords != null) { readTransform = readTransform.withMaxNumRecords(configuration.maxNumRecords); @@ -260,7 +289,34 @@ public PTransform> buildExternal( readTransform = readTransform.withInitialTimestampInStream(configuration.initialTimestampInStream); } - return readTransform; + + return new KinesisReadToBytes(readTransform); + } + } + + public static class KinesisReadToBytes extends PTransform> { + private KinesisIO.Read readTransform; + + private KinesisReadToBytes(KinesisIO.Read readTransform) { + this.readTransform = readTransform; + } + + @Override + public PCollection expand(PBegin input) { + // Convert back to bytes to keep consistency with previous verison: + // https://github.com/apache/beam/blob/5eed396caf9e0065d8ed82edcc236bad5b71ba22/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisTransformRegistrar.java + return input + .apply(this.readTransform) + .apply( + "Convert to bytes", + ParDo.of( + new DoFn() { + @ProcessElement + public byte[] processElement(ProcessContext c) { + KinesisRecord record = c.element(); + return record.getDataAsBytes(); + } + })); } } } diff --git 
a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index cc8eccf98997..38bee450e752 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -34,6 +34,9 @@ configurations.runtimeClasspath { } shadowJar { + manifest { + attributes(["Multi-Release": true]) + } mergeServiceFiles() outputs.upToDateWhen { false } } @@ -54,12 +57,13 @@ dependencies { permitUnusedDeclared project(":sdks:java:io:kafka:upgrade") // BEAM-11761 // **** IcebergIO runtime dependencies **** + runtimeOnly library.java.hadoop_auth runtimeOnly library.java.hadoop_client // Needed when using GCS as the warehouse location. runtimeOnly library.java.bigdataoss_gcs_connector // Needed for HiveCatalog runtimeOnly ("org.apache.iceberg:iceberg-hive-metastore:1.4.2") - runtimeOnly project(path: ":sdks:java:io:iceberg:hive:exec", configuration: "shadow") + runtimeOnly project(path: ":sdks:java:io:iceberg:hive") runtimeOnly library.java.kafka_clients runtimeOnly library.java.slf4j_jdk14 diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle index 01181721e9a4..b5ce11853f6c 100644 --- a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle +++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle @@ -49,5 +49,8 @@ task runExpansionService (type: JavaExec) { } shadowJar { + manifest { + attributes(["Multi-Release": true]) + } outputs.upToDateWhen { false } } \ No newline at end of file diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProto.java index 0b7e17b89090..c721395eec79 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProto.java @@ -86,7 +86,7 @@ public class AvroGenericRecordToStorageApiProto { .put(Schema.Type.STRING, Object::toString) .put(Schema.Type.BOOLEAN, Function.identity()) .put(Schema.Type.ENUM, o -> o.toString()) - .put(Schema.Type.BYTES, o -> ByteString.copyFrom((byte[]) o)) + .put(Schema.Type.BYTES, o -> ByteString.copyFrom(((ByteBuffer) o).duplicate())) .build(); // A map of supported logical types to their encoding functions. diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java index cddde05b194c..1af44ba7a012 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java @@ -34,6 +34,8 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Set; import org.apache.avro.Conversions; import org.apache.avro.LogicalType; @@ -50,14 +52,14 @@ import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -/** - * A set of utilities for working with Avro files. - * - *

These utilities are based on the Avro - * 1.8.1 specification. - */ +/** A set of utilities for working with Avro files. */ class BigQueryAvroUtils { + private static final String VERSION_AVRO = + Optional.ofNullable(Schema.class.getPackage()) + .map(Package::getImplementationVersion) + .orElse(""); + // org.apache.avro.LogicalType static class DateTimeLogicalType extends LogicalType { public DateTimeLogicalType() { @@ -74,6 +76,8 @@ public DateTimeLogicalType() { * export * @see BQ * avro storage + * @see BQ avro + * load */ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTypes) { String bqType = schema.getType(); @@ -116,6 +120,9 @@ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTy } case "DATETIME": if (useAvroLogicalTypes) { + // BQ export uses a custom logical type + // TODO for load/storage use + // LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType()) return DATETIME_LOGICAL_TYPE.addToSchema(SchemaBuilder.builder().stringType()); } else { return SchemaBuilder.builder().stringBuilder().prop("sqlType", bqType).endString(); @@ -158,6 +165,12 @@ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTy @VisibleForTesting static String formatTimestamp(Long timestampMicro) { + String dateTime = formatDatetime(timestampMicro); + return dateTime + " UTC"; + } + + @VisibleForTesting + static String formatDatetime(Long timestampMicro) { // timestampMicro is in "microseconds since epoch" format, // e.g., 1452062291123456L means "2016-01-06 06:38:11.123456 UTC". // Separate into seconds and microseconds. 
@@ -168,11 +181,13 @@ static String formatTimestamp(Long timestampMicro) { timestampSec -= 1; } String dayAndTime = DATE_AND_SECONDS_FORMATTER.print(timestampSec * 1000); - if (micros == 0) { - return String.format("%s UTC", dayAndTime); + return dayAndTime; + } else if (micros % 1000 == 0) { + return String.format("%s.%03d", dayAndTime, micros / 1000); + } else { + return String.format("%s.%06d", dayAndTime, micros); } - return String.format("%s.%06d UTC", dayAndTime, micros); } /** @@ -274,8 +289,7 @@ static TableRow convertGenericRecordToTableRow(GenericRecord record) { case UNION: return convertNullableField(name, schema, v); case MAP: - throw new UnsupportedOperationException( - String.format("Unexpected Avro field schema type %s for field named %s", type, name)); + return convertMapField(name, schema, v); default: return convertRequiredField(name, schema, v); } @@ -296,6 +310,26 @@ private static List convertRepeatedField(String name, Schema elementType return values; } + private static List convertMapField(String name, Schema map, Object v) { + // Avro maps are represented as key/value RECORD. + if (v == null) { + // Handle the case of an empty map. + return new ArrayList<>(); + } + + Schema type = map.getValueType(); + Map elements = (Map) v; + ArrayList values = new ArrayList<>(); + for (Map.Entry element : elements.entrySet()) { + TableRow row = + new TableRow() + .set("key", element.getKey()) + .set("value", convertRequiredField(name, type, element.getValue())); + values.add(row); + } + return values; + } + private static Object convertRequiredField(String name, Schema schema, Object v) { // REQUIRED fields are represented as the corresponding Avro types. For example, a BigQuery // INTEGER type maps to an Avro LONG type. 
@@ -305,45 +339,83 @@ private static Object convertRequiredField(String name, Schema schema, Object v) LogicalType logicalType = schema.getLogicalType(); switch (type) { case BOOLEAN: - // SQL types BOOL, BOOLEAN + // SQL type BOOL (BOOLEAN) return v; case INT: if (logicalType instanceof LogicalTypes.Date) { - // SQL types DATE + // SQL type DATE + // ideally LocalDate but TableRowJsonCoder encodes as String return formatDate((Integer) v); + } else if (logicalType instanceof LogicalTypes.TimeMillis) { + // Write only: SQL type TIME + // ideally LocalTime but TableRowJsonCoder encodes as String + return formatTime(((Integer) v) * 1000L); } else { - throw new UnsupportedOperationException( - String.format("Unexpected Avro field schema type %s for field named %s", type, name)); + // Write only: SQL type INT64 (INT, SMALLINT, INTEGER, BIGINT, TINYINT, BYTEINT) + // ideally Integer but keep consistency with BQ JSON export that uses String + return ((Integer) v).toString(); } case LONG: if (logicalType instanceof LogicalTypes.TimeMicros) { - // SQL types TIME + // SQL type TIME + // ideally LocalTime but TableRowJsonCoder encodes as String return formatTime((Long) v); + } else if (logicalType instanceof LogicalTypes.TimestampMillis) { + // Write only: SQL type TIMESTAMP + // ideally Instant but TableRowJsonCoder encodes as String + return formatTimestamp((Long) v * 1000L); } else if (logicalType instanceof LogicalTypes.TimestampMicros) { - // SQL types TIMESTAMP + // SQL type TIMESTAMP + // ideally Instant but TableRowJsonCoder encodes as String return formatTimestamp((Long) v); + } else if (!(VERSION_AVRO.startsWith("1.8") || VERSION_AVRO.startsWith("1.9")) + && logicalType instanceof LogicalTypes.LocalTimestampMillis) { + // Write only: SQL type DATETIME + // ideally LocalDateTime but TableRowJsonCoder encodes as String + return formatDatetime(((Long) v) * 1000); + } else if (!(VERSION_AVRO.startsWith("1.8") || VERSION_AVRO.startsWith("1.9")) + && logicalType 
instanceof LogicalTypes.LocalTimestampMicros) { + // Write only: SQL type DATETIME + // ideally LocalDateTime but TableRowJsonCoder encodes as String + return formatDatetime((Long) v); } else { - // SQL types INT64 (INT, SMALLINT, INTEGER, BIGINT, TINYINT, BYTEINT) + // SQL type INT64 (INT, SMALLINT, INTEGER, BIGINT, TINYINT, BYTEINT) + // ideally Long if in [2^53+1, 2^53-1] but keep consistency with BQ JSON export that uses + // String return ((Long) v).toString(); } + case FLOAT: + // Write only: SQL type FLOAT64 + // ideally Float but TableRowJsonCoder decodes as Double + return Double.valueOf(v.toString()); case DOUBLE: - // SQL types FLOAT64 + // SQL type FLOAT64 return v; case BYTES: if (logicalType instanceof LogicalTypes.Decimal) { // SQL tpe NUMERIC, BIGNUMERIC + // ideally BigDecimal but TableRowJsonCoder encodes as String return new Conversions.DecimalConversion() .fromBytes((ByteBuffer) v, schema, logicalType) .toString(); } else { - // SQL types BYTES + // SQL type BYTES + // ideally byte[] but TableRowJsonCoder encodes as String return BaseEncoding.base64().encode(((ByteBuffer) v).array()); } case STRING: // SQL types STRING, DATETIME, GEOGRAPHY, JSON // when not using logical type DATE, TIME too return v.toString(); + case ENUM: + // SQL types STRING + return v.toString(); + case FIXED: + // SQL type BYTES + // ideally byte[] but TableRowJsonCoder encodes as String + return BaseEncoding.base64().encode(((ByteBuffer) v).array()); case RECORD: + // SQL types RECORD return convertGenericRecordToTableRow((GenericRecord) v); default: throw new UnsupportedOperationException( diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 9a7f3a05556c..ca9dfdb65caf 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -54,6 +54,7 @@ import java.io.IOException; import java.io.Serializable; import java.lang.reflect.InvocationTargetException; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -629,6 +630,9 @@ public class BigQueryIO { private static final SerializableFunction DEFAULT_AVRO_SCHEMA_FACTORY = BigQueryAvroUtils::toGenericAvroSchema; + static final String CONNECTION_ID = "connectionId"; + static final String STORAGE_URI = "storageUri"; + /** * @deprecated Use {@link #read(SerializableFunction)} or {@link #readTableRows} instead. {@link * #readTableRows()} does exactly the same as {@link #read}, however {@link @@ -2372,6 +2376,8 @@ public enum Method { /** Table description. Default is empty. */ abstract @Nullable String getTableDescription(); + abstract @Nullable Map getBigLakeConfiguration(); + /** An option to indicate if table validation is desired. Default is true. */ abstract boolean getValidate(); @@ -2484,6 +2490,8 @@ abstract Builder setAvroSchemaFactory( abstract Builder setTableDescription(String tableDescription); + abstract Builder setBigLakeConfiguration(Map bigLakeConfiguration); + abstract Builder setValidate(boolean validate); abstract Builder setBigQueryServices(BigQueryServices bigQueryServices); @@ -2909,6 +2917,30 @@ public Write withTableDescription(String tableDescription) { return toBuilder().setTableDescription(tableDescription).build(); } + /** + * Specifies a configuration to create BigLake tables. The following options are available: + * + *
    + *
  • connectionId (REQUIRED): the name of your cloud resource connection. + *
  • storageUri (REQUIRED): the path to your GCS folder where data will be written to. This + * sink will create sub-folders for each project, dataset, and table destination. Example: + * if you specify a storageUri of {@code "gs://foo/bar"} and writing to table {@code + * "my_project.my_dataset.my_table"}, your data will be written under {@code + * "gs://foo/bar/my_project/my_dataset/my_table/"} + *
  • fileFormat (OPTIONAL): defaults to {@code "parquet"} + *
  • tableFormat (OPTIONAL): defaults to {@code "iceberg"} + *
+ * + *

NOTE: This is only supported with the Storage Write API methods. + * + * @see BigQuery Tables for + * Apache Iceberg documentation + */ + public Write withBigLakeConfiguration(Map bigLakeConfiguration) { + checkArgument(bigLakeConfiguration != null, "bigLakeConfiguration can not be null"); + return toBuilder().setBigLakeConfiguration(bigLakeConfiguration).build(); + } + /** * Specifies a policy for handling failed inserts. * @@ -3008,9 +3040,14 @@ public Write withNumFileShards(int numFileShards) { } /** - * Control how many parallel streams are used when using Storage API writes. Applicable only for - * streaming pipelines, and when {@link #withTriggeringFrequency} is also set. To let runner - * determine the sharding at runtime, set this to zero, or {@link #withAutoSharding()} instead. + * Control how many parallel streams are used when using Storage API writes. + * + *

For streaming pipelines, this is applicable when {@link #withTriggeringFrequency} is also set. To let + runner determine the sharding at runtime, set this to zero, or {@link #withAutoSharding()} + instead. + + *

For batch pipelines, it inserts a redistribute. To not reshuffle and keep the pipeline + parallelism as is, set this to zero. */ public Write withNumStorageWriteApiStreams(int numStorageWriteApiStreams) { return toBuilder().setNumStorageWriteApiStreams(numStorageWriteApiStreams).build(); } @@ -3454,8 +3491,21 @@ && getStorageApiTriggeringFrequency(bqOptions) != null) { checkArgument( !getAutoSchemaUpdate(), "withAutoSchemaUpdate only supported when using STORAGE_WRITE_API or STORAGE_API_AT_LEAST_ONCE."); - } else if (getWriteDisposition() == WriteDisposition.WRITE_TRUNCATE) { - LOG.error("The Storage API sink does not support the WRITE_TRUNCATE write disposition."); + checkArgument( + getBigLakeConfiguration() == null, + "bigLakeConfiguration is only supported when using STORAGE_WRITE_API or STORAGE_API_AT_LEAST_ONCE."); + } else { + if (getWriteDisposition() == WriteDisposition.WRITE_TRUNCATE) { + LOG.error("The Storage API sink does not support the WRITE_TRUNCATE write disposition."); + } + if (getBigLakeConfiguration() != null) { + checkArgument( + Arrays.stream(new String[] {CONNECTION_ID, STORAGE_URI}) + .allMatch(getBigLakeConfiguration()::containsKey), + String.format( + "bigLakeConfiguration must contain keys '%s' and '%s'", + CONNECTION_ID, STORAGE_URI)); + } } if (getRowMutationInformationFn() != null) { checkArgument( @@ -3905,6 +3955,7 @@ private WriteResult continueExpandTyped( getPropagateSuccessfulStorageApiWritesPredicate(), getRowMutationInformationFn() != null, getDefaultMissingValueInterpretation(), + getBigLakeConfiguration(), getBadRecordRouter(), getBadRecordErrorHandler()); return input.apply("StorageApiLoads", storageApiLoads); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java index 561f5ccfc457..1da47156dda7 100644 ---
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java @@ -393,6 +393,7 @@ static class BigQueryIOWriteTranslator implements TransformPayloadTranslator transform) { if (transform.getTableDescription() != null) { fieldValues.put("table_description", transform.getTableDescription()); } + if (transform.getBigLakeConfiguration() != null) { + fieldValues.put("biglake_configuration", transform.getBigLakeConfiguration()); + } fieldValues.put("validate", transform.getValidate()); if (transform.getBigQueryServices() != null) { fieldValues.put("bigquery_services", toByteArray(transform.getBigQueryServices())); @@ -719,6 +723,10 @@ public Write fromConfigRow(Row configRow, PipelineOptions options) { if (tableDescription != null) { builder = builder.setTableDescription(tableDescription); } + Map biglakeConfiguration = configRow.getMap("biglake_configuration"); + if (biglakeConfiguration != null) { + builder = builder.setBigLakeConfiguration(biglakeConfiguration); + } Boolean validate = configRow.getBoolean("validate"); if (validate != null) { builder = builder.setValidate(validate); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTableHelpers.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTableHelpers.java index 7a94657107ec..7c428917503f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTableHelpers.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTableHelpers.java @@ -17,11 +17,14 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.CONNECTION_ID; +import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.STORAGE_URI; import 
static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import com.google.api.client.util.BackOff; import com.google.api.client.util.BackOffUtils; import com.google.api.gax.rpc.ApiException; +import com.google.api.services.bigquery.model.BigLakeConfiguration; import com.google.api.services.bigquery.model.Clustering; import com.google.api.services.bigquery.model.EncryptionConfiguration; import com.google.api.services.bigquery.model.Table; @@ -31,6 +34,7 @@ import com.google.api.services.bigquery.model.TimePartitioning; import io.grpc.StatusRuntimeException; import java.util.Collections; +import java.util.Map; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -41,6 +45,7 @@ import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.checkerframework.checker.nullness.qual.Nullable; @@ -91,7 +96,8 @@ static TableDestination possiblyCreateTable( CreateDisposition createDisposition, @Nullable Coder tableDestinationCoder, @Nullable String kmsKey, - BigQueryServices bqServices) { + BigQueryServices bqServices, + @Nullable Map bigLakeConfiguration) { checkArgument( tableDestination.getTableSpec() != null, "DynamicDestinations.getTable() must return a TableDestination " @@ -132,7 +138,8 @@ static TableDestination possiblyCreateTable( createDisposition, tableSpec, kmsKey, - bqServices); + bqServices, + bigLakeConfiguration); } } } @@ -147,7 +154,8 @@ private static void tryCreateTable( CreateDisposition createDisposition, String tableSpec, @Nullable String kmsKey, - BigQueryServices 
bqServices) { + BigQueryServices bqServices, + @Nullable Map bigLakeConfiguration) { TableReference tableReference = tableDestination.getTableReference().clone(); tableReference.setTableId(BigQueryHelpers.stripPartitionDecorator(tableReference.getTableId())); try (DatasetService datasetService = bqServices.getDatasetService(options)) { @@ -189,6 +197,24 @@ private static void tryCreateTable( if (kmsKey != null) { table.setEncryptionConfiguration(new EncryptionConfiguration().setKmsKeyName(kmsKey)); } + if (bigLakeConfiguration != null) { + TableReference ref = table.getTableReference(); + table.setBiglakeConfiguration( + new BigLakeConfiguration() + .setTableFormat( + MoreObjects.firstNonNull(bigLakeConfiguration.get("tableFormat"), "iceberg")) + .setFileFormat( + MoreObjects.firstNonNull(bigLakeConfiguration.get("fileFormat"), "parquet")) + .setConnectionId( + Preconditions.checkArgumentNotNull(bigLakeConfiguration.get(CONNECTION_ID))) + .setStorageUri( + String.format( + "%s/%s/%s/%s", + Preconditions.checkArgumentNotNull(bigLakeConfiguration.get(STORAGE_URI)), + ref.getProjectId(), + ref.getDatasetId(), + ref.getTableId()))); + } datasetService.createTable(table); } } catch (Exception e) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java index 1bbd4e756084..7008c049a4a5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java @@ -132,7 +132,8 @@ public void processElement(ProcessContext context) { createDisposition, dynamicDestinations.getDestinationCoder(), kmsKey, - bqServices); + bqServices, + null); }); context.output(KV.of(tableDestination, context.element().getValue())); diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 4ca9d5035c81..fcf67a8062ac 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -23,6 +23,7 @@ import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Predicate; import javax.annotation.Nullable; @@ -35,6 +36,7 @@ import org.apache.beam.sdk.transforms.GroupIntoBatches; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Redistribute; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.errorhandling.BadRecord; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; @@ -76,6 +78,7 @@ public class StorageApiLoads private final boolean usesCdc; private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final Map bigLakeConfiguration; private final BadRecordRouter badRecordRouter; @@ -98,6 +101,7 @@ public StorageApiLoads( Predicate propagateSuccessfulStorageApiWritesPredicate, boolean usesCdc, AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + Map bigLakeConfiguration, BadRecordRouter badRecordRouter, ErrorHandler badRecordErrorHandler) { this.destinationCoder = destinationCoder; @@ -118,6 +122,7 @@ public StorageApiLoads( this.successfulRowsPredicate = propagateSuccessfulStorageApiWritesPredicate; this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + 
this.bigLakeConfiguration = bigLakeConfiguration; this.badRecordRouter = badRecordRouter; this.badRecordErrorHandler = badRecordErrorHandler; } @@ -186,7 +191,8 @@ public WriteResult expandInconsistent( createDisposition, kmsKey, usesCdc, - defaultMissingValueInterpretation)); + defaultMissingValueInterpretation, + bigLakeConfiguration)); PCollection insertErrors = PCollectionList.of(convertMessagesResult.get(failedRowsTag)) @@ -279,7 +285,8 @@ public WriteResult expandTriggered( successfulRowsPredicate, autoUpdateSchema, ignoreUnknownValues, - defaultMissingValueInterpretation)); + defaultMissingValueInterpretation, + bigLakeConfiguration)); PCollection insertErrors = PCollectionList.of(convertMessagesResult.get(failedRowsTag)) @@ -354,25 +361,35 @@ public WriteResult expandUntriggered( rowUpdateFn, badRecordRouter)); + PCollection> successfulConvertedRows = + convertMessagesResult.get(successfulConvertedRowsTag); + + if (numShards > 0) { + successfulConvertedRows = + successfulConvertedRows.apply( + "ResdistibuteNumShards", + Redistribute.>arbitrarily() + .withNumBuckets(numShards)); + } + PCollectionTuple writeRecordsResult = - convertMessagesResult - .get(successfulConvertedRowsTag) - .apply( - "StorageApiWriteUnsharded", - new StorageApiWriteUnshardedRecords<>( - dynamicDestinations, - bqServices, - failedRowsTag, - successfulWrittenRowsTag, - successfulRowsPredicate, - BigQueryStorageApiInsertErrorCoder.of(), - TableRowJsonCoder.of(), - autoUpdateSchema, - ignoreUnknownValues, - createDisposition, - kmsKey, - usesCdc, - defaultMissingValueInterpretation)); + successfulConvertedRows.apply( + "StorageApiWriteUnsharded", + new StorageApiWriteUnshardedRecords<>( + dynamicDestinations, + bqServices, + failedRowsTag, + successfulWrittenRowsTag, + successfulRowsPredicate, + BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), + autoUpdateSchema, + ignoreUnknownValues, + createDisposition, + kmsKey, + usesCdc, + defaultMissingValueInterpretation, + 
bigLakeConfiguration)); PCollection insertErrors = PCollectionList.of(convertMessagesResult.get(failedRowsTag)) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java index 0860b4eda8a2..58bbed8ba5a9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java @@ -19,6 +19,7 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; +import java.util.Map; import java.util.function.Predicate; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; @@ -55,6 +56,7 @@ public class StorageApiWriteRecordsInconsistent private final @Nullable String kmsKey; private final boolean usesCdc; private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final @Nullable Map bigLakeConfiguration; public StorageApiWriteRecordsInconsistent( StorageApiDynamicDestinations dynamicDestinations, @@ -69,7 +71,8 @@ public StorageApiWriteRecordsInconsistent( BigQueryIO.Write.CreateDisposition createDisposition, @Nullable String kmsKey, boolean usesCdc, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable Map bigLakeConfiguration) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedRowsTag = failedRowsTag; @@ -83,6 +86,7 @@ public StorageApiWriteRecordsInconsistent( this.kmsKey = kmsKey; this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + 
this.bigLakeConfiguration = bigLakeConfiguration; } @Override @@ -116,7 +120,8 @@ public PCollectionTuple expand(PCollection private final @Nullable String kmsKey; private final boolean usesCdc; private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final @Nullable Map bigLakeConfiguration; /** * The Guava cache object is thread-safe. However our protocol requires that client pin the @@ -179,7 +180,8 @@ public StorageApiWriteUnshardedRecords( BigQueryIO.Write.CreateDisposition createDisposition, @Nullable String kmsKey, boolean usesCdc, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable Map bigLakeConfiguration) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedRowsTag = failedRowsTag; @@ -193,6 +195,7 @@ public StorageApiWriteUnshardedRecords( this.kmsKey = kmsKey; this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.bigLakeConfiguration = bigLakeConfiguration; } @Override @@ -228,7 +231,8 @@ public PCollectionTuple expand(PCollection bigLakeConfiguration; WriteRecordsDoFn( String operationName, @@ -973,7 +978,8 @@ void postFlush() { @Nullable String kmsKey, boolean usesCdc, AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, - int maxRetries) { + int maxRetries, + @Nullable Map bigLakeConfiguration) { this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; @@ -992,6 +998,7 @@ void postFlush() { this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; this.maxRetries = maxRetries; + this.bigLakeConfiguration = bigLakeConfiguration; } boolean shouldFlush() { @@ -1098,7 +1105,8 @@ DestinationState createDestinationState( 
createDisposition, destinationCoder, kmsKey, - bqServices); + bqServices, + bigLakeConfiguration); return true; }; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java index e2674fe34f2e..738a52b69cb7 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java @@ -131,6 +131,7 @@ public class StorageApiWritesShardedRecords bigLakeConfiguration; private final Duration streamIdleTime = DEFAULT_STREAM_IDLE_TIME; private final TupleTag failedRowsTag; @@ -232,7 +233,8 @@ public StorageApiWritesShardedRecords( Predicate successfulRowsPredicate, boolean autoUpdateSchema, boolean ignoreUnknownValues, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable Map bigLakeConfiguration) { this.dynamicDestinations = dynamicDestinations; this.createDisposition = createDisposition; this.kmsKey = kmsKey; @@ -246,6 +248,7 @@ public StorageApiWritesShardedRecords( this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.bigLakeConfiguration = bigLakeConfiguration; } @Override @@ -499,7 +502,8 @@ public void process( createDisposition, destinationCoder, kmsKey, - bqServices); + bqServices, + bigLakeConfiguration); return true; }; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java index 7872c91d1f72..8899ac82eb06 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java @@ -25,7 +25,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.schemas.Schema; @@ -97,20 +96,22 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { return PCollectionRowTuple.empty(input.getPipeline()); } - BigQueryIO.Write toWrite(Schema schema, PipelineOptions options) { + @VisibleForTesting + public BigQueryIO.Write toWrite(Schema schema, PipelineOptions options) { PortableBigQueryDestinations dynamicDestinations = new PortableBigQueryDestinations(schema, configuration); BigQueryIO.Write write = BigQueryIO.write() .to(dynamicDestinations) .withMethod(BigQueryIO.Write.Method.FILE_LOADS) - .withFormatFunction(BigQueryUtils.toTableRow()) // TODO(https://github.com/apache/beam/issues/33074) BatchLoad's // createTempFilePrefixView() doesn't pick up the pipeline option .withCustomGcsTempLocation( ValueProvider.StaticValueProvider.of(options.getTempLocation())) .withWriteDisposition(WriteDisposition.WRITE_APPEND) - .withFormatFunction(dynamicDestinations.getFilterFormatFunction(false)); + // Use Avro format for better performance. Don't change this unless it's for a good + // reason. 
+ .withAvroFormatFunction(dynamicDestinations.getAvroFilterFormatFunction(false)); if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java index 54d125012eac..0cd2b65b0858 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java @@ -25,7 +25,10 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import java.util.List; +import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.io.gcp.bigquery.AvroWriteRequest; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; @@ -102,4 +105,16 @@ public SerializableFunction getFilterFormatFunction(boolean fetch return BigQueryUtils.toTableRow(filtered); }; } + + public SerializableFunction, GenericRecord> getAvroFilterFormatFunction( + boolean fetchNestedRecord) { + return request -> { + Row row = request.getElement(); + if (fetchNestedRecord) { + row = checkStateNotNull(row.getRow(RECORD)); + } + Row filtered = rowFilter.filter(row); + return AvroUtils.toGenericRecord(filtered); + }; + } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java index 37d0227555b7..e9d0709343f5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java @@ -1197,7 +1197,7 @@ static final class PartitionQueryResponseToRunQueryRequest .filter( v -> { String referenceValue = v.getReferenceValue(); - return referenceValue != null && !referenceValue.isEmpty(); + return !referenceValue.isEmpty(); }) .findFirst(); Function stringToPath = (String s) -> s.split("/"); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java index 0cfb06688108..de6c14722898 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java @@ -269,13 +269,12 @@ public List pull( List incomingMessages = new ArrayList<>(response.getReceivedMessagesCount()); for (ReceivedMessage message : response.getReceivedMessagesList()) { PubsubMessage pubsubMessage = message.getMessage(); - @Nullable Map attributes = pubsubMessage.getAttributes(); + Map attributes = pubsubMessage.getAttributes(); // Timestamp. long timestampMsSinceEpoch; if (Strings.isNullOrEmpty(timestampAttribute)) { Timestamp timestampProto = pubsubMessage.getPublishTime(); - checkArgument(timestampProto != null, "Pubsub message is missing timestamp proto"); timestampMsSinceEpoch = timestampProto.getSeconds() * 1000 + timestampProto.getNanos() / 1000L / 1000L; } else { @@ -288,7 +287,7 @@ public List pull( // Record id, if any. 
@Nullable String recordId = null; - if (idAttribute != null && attributes != null) { + if (idAttribute != null) { recordId = attributes.get(idAttribute); } if (Strings.isNullOrEmpty(recordId)) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java index 0a838da66f69..a12c64ff9a9b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java @@ -158,12 +158,7 @@ public int publish(TopicPath topic, List outgoingMessages) thro } private Map getMessageAttributes(OutgoingMessage outgoingMessage) { - Map attributes = null; - if (outgoingMessage.getMessage().getAttributesMap() == null) { - attributes = new TreeMap<>(); - } else { - attributes = new TreeMap<>(outgoingMessage.getMessage().getAttributesMap()); - } + Map attributes = new TreeMap<>(outgoingMessage.getMessage().getAttributesMap()); if (timestampAttribute != null) { attributes.put( timestampAttribute, String.valueOf(outgoingMessage.getTimestampMsSinceEpoch())); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java index 76440b1ebf1a..0bcf6e0c4f75 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java @@ -76,40 +76,34 @@ public String description() { + "\n" + "Example configuration for performing a read using a SQL query: ::\n" + "\n" - + " pipeline:\n" - + " 
transforms:\n" - + " - type: ReadFromSpanner\n" - + " config:\n" - + " instance_id: 'my-instance-id'\n" - + " database_id: 'my-database'\n" - + " query: 'SELECT * FROM table'\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " query: 'SELECT * FROM table'\n" + "\n" + "It is also possible to read a table by specifying a table name and a list of columns. For " + "example, the following configuration will perform a read on an entire table: ::\n" + "\n" - + " pipeline:\n" - + " transforms:\n" - + " - type: ReadFromSpanner\n" - + " config:\n" - + " instance_id: 'my-instance-id'\n" - + " database_id: 'my-database'\n" - + " table: 'my-table'\n" - + " columns: ['col1', 'col2']\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " columns: ['col1', 'col2']\n" + "\n" + "Additionally, to read using a " + "Secondary Index, specify the index name: ::" + "\n" - + " pipeline:\n" - + " transforms:\n" - + " - type: ReadFromSpanner\n" - + " config:\n" - + " instance_id: 'my-instance-id'\n" - + " database_id: 'my-database'\n" - + " table: 'my-table'\n" - + " index: 'my-index'\n" - + " columns: ['col1', 'col2']\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " index: 'my-index'\n" + + " columns: ['col1', 'col2']\n" + "\n" - + "### Advanced Usage\n" + + "#### Advanced Usage\n" + "\n" + "Reads by default use the " + "PartitionQuery API which enforces some limitations on the type of queries that can be used so that " @@ -118,12 +112,10 @@ public String description() { + "\n" + "For example: ::" + "\n" - + " pipeline:\n" - + " transforms:\n" - + " - type: ReadFromSpanner\n" - + " config:\n" - + " batching: false\n" - + " ...\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " batching: 
false\n" + + " ...\n" + "\n" + "Note: See " diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java index 8601da09ea09..61955f448c3f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java @@ -84,14 +84,12 @@ public String description() { + "\n" + "Example configuration for performing a write to a single table: ::\n" + "\n" - + " pipeline:\n" - + " transforms:\n" - + " - type: ReadFromSpanner\n" - + " config:\n" - + " project_id: 'my-project-id'\n" - + " instance_id: 'my-instance-id'\n" - + " database_id: 'my-database'\n" - + " table: 'my-table'\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " project_id: 'my-project-id'\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + "\n" + "Note: See " diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordAction.java index ada794d20c3b..14b6b2e2453a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordAction.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordAction.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChildPartition; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChildPartitionsRecord; import 
org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; @@ -94,17 +95,21 @@ public class ChildPartitionsRecordAction { * @param record the change stream child partition record received * @param tracker the restriction tracker of the {@link * org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.ReadChangeStreamPartitionDoFn} SDF + * @param interrupter the restriction interrupter suggesting early termination of the processing * @param watermarkEstimator the watermark estimator of the {@link * org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.ReadChangeStreamPartitionDoFn} SDF * @return {@link Optional#empty()} if the caller can continue processing more records. A non * empty {@link Optional} with {@link ProcessContinuation#stop()} if this function was unable - * to claim the {@link ChildPartitionsRecord} timestamp + * to claim the {@link ChildPartitionsRecord} timestamp. A non empty {@link Optional} with + * {@link ProcessContinuation#resume()} if this function should commit what has already been + * processed and resume. 
*/ @VisibleForTesting public Optional run( PartitionMetadata partition, ChildPartitionsRecord record, RestrictionTracker tracker, + RestrictionInterrupter interrupter, ManualWatermarkEstimator watermarkEstimator) { final String token = partition.getPartitionToken(); @@ -113,6 +118,13 @@ public Optional run( final Timestamp startTimestamp = record.getStartTimestamp(); final Instant startInstant = new Instant(startTimestamp.toSqlTimestamp().getTime()); + if (interrupter.tryInterrupt(startTimestamp)) { + LOG.debug( + "[{}] Soft deadline reached with child partitions record at {}, rescheduling", + token, + startTimestamp); + return Optional.of(ProcessContinuation.resume()); + } if (!tracker.tryClaim(startTimestamp)) { LOG.debug("[{}] Could not claim queryChangeStream({}), stopping", token, startTimestamp); return Optional.of(ProcessContinuation.stop()); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordAction.java index 4ceda8afb3e6..555b1fefbebc 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordAction.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordAction.java @@ -21,9 +21,9 @@ import java.util.Optional; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.ReadChangeStreamPartitionDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.estimator.ThroughputEstimator; -import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChildPartitionsRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import 
org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; @@ -68,18 +68,22 @@ public DataChangeRecordAction(ThroughputEstimator throughputEs * @param partition the current partition being processed * @param record the change stream data record received * @param tracker the restriction tracker of the {@link ReadChangeStreamPartitionDoFn} SDF + * @param interrupter the restriction interrupter suggesting early termination of the processing * @param outputReceiver the output receiver of the {@link ReadChangeStreamPartitionDoFn} SDF * @param watermarkEstimator the watermark estimator of the {@link ReadChangeStreamPartitionDoFn} * SDF * @return {@link Optional#empty()} if the caller can continue processing more records. A non * empty {@link Optional} with {@link ProcessContinuation#stop()} if this function was unable - * to claim the {@link ChildPartitionsRecord} timestamp + * to claim the {@link DataChangeRecord} timestamp. A non empty {@link Optional} with {@link + * ProcessContinuation#resume()} if this function should commit what has already been + * processed and resume. 
*/ @VisibleForTesting public Optional run( PartitionMetadata partition, DataChangeRecord record, RestrictionTracker tracker, + RestrictionInterrupter interrupter, OutputReceiver outputReceiver, ManualWatermarkEstimator watermarkEstimator) { @@ -88,6 +92,13 @@ public Optional run( final Timestamp commitTimestamp = record.getCommitTimestamp(); final Instant commitInstant = new Instant(commitTimestamp.toSqlTimestamp().getTime()); + if (interrupter.tryInterrupt(commitTimestamp)) { + LOG.debug( + "[{}] Soft deadline reached with data change record at {}, rescheduling", + token, + commitTimestamp); + return Optional.of(ProcessContinuation.resume()); + } if (!tracker.tryClaim(commitTimestamp)) { LOG.debug("[{}] Could not claim queryChangeStream({}), stopping", token, commitTimestamp); return Optional.of(ProcessContinuation.stop()); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordAction.java index 83a232fe2093..0937e896fbf1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordAction.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordAction.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamMetrics; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.HeartbeatRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import 
org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; @@ -56,7 +57,9 @@ public class HeartbeatRecordAction { * not. If the {@link Optional} returned is empty, it means that the calling function can continue * with the processing. If an {@link Optional} of {@link ProcessContinuation#stop()} is returned, * it means that this function was unable to claim the timestamp of the {@link HeartbeatRecord}, - * so the caller should stop. + * so the caller should stop. If an {@link Optional} of {@link ProcessContinuation#resume()} is + * returned, it means that this function should not attempt to claim further timestamps of the + * {@link HeartbeatRecord}, but instead should commit what it has processed so far. * *

When processing the {@link HeartbeatRecord} the following procedure is applied: * @@ -72,6 +75,7 @@ public Optional run( PartitionMetadata partition, HeartbeatRecord record, RestrictionTracker tracker, + RestrictionInterrupter interrupter, ManualWatermarkEstimator watermarkEstimator) { final String token = partition.getPartitionToken(); @@ -79,6 +83,11 @@ public Optional run( final Timestamp timestamp = record.getTimestamp(); final Instant timestampInstant = new Instant(timestamp.toSqlTimestamp().getTime()); + if (interrupter.tryInterrupt(timestamp)) { + LOG.debug( + "[{}] Soft deadline reached with heartbeat record at {}, rescheduling", token, timestamp); + return Optional.of(ProcessContinuation.resume()); + } if (!tracker.tryClaim(timestamp)) { LOG.debug("[{}] Could not claim queryChangeStream({}), stopping", token, timestamp); return Optional.of(ProcessContinuation.stop()); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamAction.java index 92285946e56f..6edbd544a37c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamAction.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamAction.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.HeartbeatRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import 
org.apache.beam.sdk.transforms.DoFn.OutputReceiver; @@ -62,6 +63,13 @@ public class QueryChangeStreamAction { private static final Logger LOG = LoggerFactory.getLogger(QueryChangeStreamAction.class); private static final Duration BUNDLE_FINALIZER_TIMEOUT = Duration.standardMinutes(5); + /* + * Corresponds to the best effort timeout in case the restriction tracker cannot split the processing + * interval before the hard deadline. When reached it will assure that the already processed timestamps + * will be committed instead of thrown away (DEADLINE_EXCEEDED). The value should be less than + * the RetrySetting RPC timeout setting of SpannerIO#ReadChangeStream. + */ + private static final Duration RESTRICTION_TRACKER_TIMEOUT = Duration.standardSeconds(40); private static final String OUT_OF_RANGE_ERROR_MESSAGE = "Specified start_timestamp is invalid"; private final ChangeStreamDao changeStreamDao; @@ -164,6 +172,10 @@ public ProcessContinuation run( new IllegalStateException( "Partition " + token + " not found in metadata table")); + // Interrupter with soft timeout to commit the work if any records have been processed. 
+ RestrictionInterrupter interrupter = + RestrictionInterrupter.withSoftTimeout(RESTRICTION_TRACKER_TIMEOUT); + try (ChangeStreamResultSet resultSet = changeStreamDao.changeStreamQuery( token, startTimestamp, endTimestamp, partition.getHeartbeatMillis())) { @@ -182,16 +194,25 @@ public ProcessContinuation run( updatedPartition, (DataChangeRecord) record, tracker, + interrupter, receiver, watermarkEstimator); } else if (record instanceof HeartbeatRecord) { maybeContinuation = heartbeatRecordAction.run( - updatedPartition, (HeartbeatRecord) record, tracker, watermarkEstimator); + updatedPartition, + (HeartbeatRecord) record, + tracker, + interrupter, + watermarkEstimator); } else if (record instanceof ChildPartitionsRecord) { maybeContinuation = childPartitionsRecordAction.run( - updatedPartition, (ChildPartitionsRecord) record, tracker, watermarkEstimator); + updatedPartition, + (ChildPartitionsRecord) record, + tracker, + interrupter, + watermarkEstimator); } else { LOG.error("[{}] Unknown record type {}", token, record.getClass()); throw new IllegalArgumentException("Unknown record type " + record.getClass()); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/restriction/RestrictionInterrupter.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/restriction/RestrictionInterrupter.java new file mode 100644 index 000000000000..37e91911867a --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/restriction/RestrictionInterrupter.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction; + +import java.util.function.Supplier; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** An interrupter for restriction tracker of type T. */ +public class RestrictionInterrupter { + private @Nullable T lastAttemptedPosition; + + private Supplier timeSupplier; + private final Instant softDeadline; + private boolean hasInterrupted = true; + + /** + * Sets a soft timeout from now for processing new positions. After the timeout the tryInterrupt + * will start returning true indicating an early exit from processing. + */ + public static RestrictionInterrupter withSoftTimeout(Duration timeout) { + return new RestrictionInterrupter(() -> Instant.now(), timeout); + } + + RestrictionInterrupter(Supplier timeSupplier, Duration timeout) { + this.timeSupplier = timeSupplier; + this.softDeadline = this.timeSupplier.get().plus(timeout); + hasInterrupted = false; + } + + @VisibleForTesting + void setTimeSupplier(Supplier timeSupplier) { + this.timeSupplier = timeSupplier; + } + + /** + * Returns true if the restriction tracker should be interrupted in claiming new positions. + * + *

    + *
  1. If soft deadline hasn't been reached always returns false. + *
  2. If soft deadline has been reached but we haven't processed any positions returns false. + *
  3. If soft deadline has been reached but the new position is the same as the last attempted + * position returns false. + *
  4. If soft deadline has been reached and the new position differs from the last attempted + * position returns true. + *
+ * + * @return {@code true} if the position processing should continue, {@code false} if the soft + * deadline has been reached and we have fully processed the previous position. + */ + public boolean tryInterrupt(@NonNull T position) { + if (hasInterrupted) { + return true; + } + if (lastAttemptedPosition == null) { + lastAttemptedPosition = position; + return false; + } + if (position.equals(lastAttemptedPosition)) { + return false; + } + lastAttemptedPosition = position; + + hasInterrupted |= timeSupplier.get().isAfter(softDeadline); + return hasInterrupted; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProtoTest.java index 6a59afeed823..472173c67412 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/AvroGenericRecordToStorageApiProtoTest.java @@ -28,6 +28,7 @@ import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; import java.math.BigDecimal; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -308,7 +309,7 @@ enum TestEnum { Instant now = Instant.now(); baseRecord = new GenericRecordBuilder(BASE_SCHEMA) - .set("bytesValue", BYTES) + .set("bytesValue", ByteBuffer.wrap(BYTES)) .set("intValue", (int) 3) .set("longValue", (long) 4) .set("floatValue", (float) 3.14) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java index 662f2658eb6b..2333278a11f5 100644 
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java @@ -28,23 +28,23 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; import org.apache.avro.Conversions; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; -import org.apache.avro.Schema.Field; -import org.apache.avro.Schema.Type; +import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; -import org.apache.avro.reflect.AvroSchema; -import org.apache.avro.reflect.Nullable; -import org.apache.avro.reflect.ReflectData; +import org.apache.avro.generic.GenericRecordBuilder; import org.apache.avro.util.Utf8; -import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding; import org.junit.Test; @@ -54,363 +54,678 @@ /** Tests for {@link BigQueryAvroUtils}. */ @RunWith(JUnit4.class) public class BigQueryAvroUtilsTest { - private List subFields = - Lists.newArrayList( - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); - /* - * Note that the quality and quantity fields do not have their mode set, so they should default - * to NULLABLE. This is an important test of BigQuery semantics. 
- * - * All the other fields we set in this function are required on the Schema response. - * - * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema - */ - private List fields = - Lists.newArrayList( - new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), - new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, - new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, - new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), - new TableFieldSchema().setName("birthdayMoney").setType("NUMERIC").setMode("NULLABLE"), - new TableFieldSchema() - .setName("lotteryWinnings") - .setType("BIGNUMERIC") - .setMode("NULLABLE"), - new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), - new TableFieldSchema().setName("sound").setType("BYTES").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryDate").setType("DATE").setMode("NULLABLE"), - new TableFieldSchema() - .setName("anniversaryDatetime") - .setType("DATETIME") - .setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryTime").setType("TIME").setMode("NULLABLE"), - new TableFieldSchema() - .setName("scion") - .setType("RECORD") - .setMode("NULLABLE") - .setFields(subFields), - new TableFieldSchema() - .setName("associates") - .setType("RECORD") - .setMode("REPEATED") - .setFields(subFields), - new TableFieldSchema().setName("geoPositions").setType("GEOGRAPHY").setMode("NULLABLE")); - - private ByteBuffer convertToBytes(BigDecimal bigDecimal, int precision, int scale) { - LogicalType bigDecimalLogicalType = LogicalTypes.decimal(precision, scale); - return new Conversions.DecimalConversion().toBytes(bigDecimal, null, bigDecimalLogicalType); + + private TableSchema tableSchema(Function fn) { + TableFieldSchema column = new 
TableFieldSchema().setName("value"); + TableSchema tableSchema = new TableSchema(); + tableSchema.setFields(Lists.newArrayList(fn.apply(column))); + return tableSchema; + } + + private Schema avroSchema( + Function, SchemaBuilder.FieldAssembler> fn) { + return fn.apply( + SchemaBuilder.record("root") + .namespace("org.apache.beam.sdk.io.gcp.bigquery") + .doc("Translated Avro Schema for root") + .fields() + .name("value")) + .endRecord(); } + @SuppressWarnings("JavaInstantGetSecondsGetNano") @Test - public void testConvertGenericRecordToTableRow() throws Exception { - BigDecimal numeric = new BigDecimal("123456789.123456789"); - ByteBuffer numericBytes = convertToBytes(numeric, 38, 9); - BigDecimal bigNumeric = - new BigDecimal( - "578960446186580977117854925043439539266.34992332820282019728792003956564819967"); - ByteBuffer bigNumericBytes = convertToBytes(bigNumeric, 77, 38); - Schema avroSchema = ReflectData.get().getSchema(Bird.class); - - { - // Test nullable fields. - GenericRecord record = new GenericData.Record(avroSchema); - record.put("number", 5L); - TableRow convertedRow = BigQueryAvroUtils.convertGenericRecordToTableRow(record); - TableRow row = new TableRow().set("number", "5").set("associates", new ArrayList()); - assertEquals(row, convertedRow); - TableRow clonedRow = convertedRow.clone(); - assertEquals(convertedRow, clonedRow); - } - { - // Test type conversion for: - // INTEGER, FLOAT, NUMERIC, TIMESTAMP, BOOLEAN, BYTES, DATE, DATETIME, TIME. 
- GenericRecord record = new GenericData.Record(avroSchema); - byte[] soundBytes = "chirp,chirp".getBytes(StandardCharsets.UTF_8); - ByteBuffer soundByteBuffer = ByteBuffer.wrap(soundBytes); - soundByteBuffer.rewind(); - record.put("number", 5L); - record.put("quality", 5.0); - record.put("birthday", 5L); - record.put("birthdayMoney", numericBytes); - record.put("lotteryWinnings", bigNumericBytes); - record.put("flighted", Boolean.TRUE); - record.put("sound", soundByteBuffer); - record.put("anniversaryDate", new Utf8("2000-01-01")); - record.put("anniversaryDatetime", new String("2000-01-01 00:00:00.000005")); - record.put("anniversaryTime", new Utf8("00:00:00.000005")); - record.put("geoPositions", new String("LINESTRING(1 2, 3 4, 5 6, 7 8)")); - TableRow convertedRow = BigQueryAvroUtils.convertGenericRecordToTableRow(record); - TableRow row = - new TableRow() - .set("number", "5") - .set("birthday", "1970-01-01 00:00:00.000005 UTC") - .set("birthdayMoney", numeric.toString()) - .set("lotteryWinnings", bigNumeric.toString()) - .set("quality", 5.0) - .set("associates", new ArrayList()) - .set("flighted", Boolean.TRUE) - .set("sound", BaseEncoding.base64().encode(soundBytes)) - .set("anniversaryDate", "2000-01-01") - .set("anniversaryDatetime", "2000-01-01 00:00:00.000005") - .set("anniversaryTime", "00:00:00.000005") - .set("geoPositions", "LINESTRING(1 2, 3 4, 5 6, 7 8)"); - TableRow clonedRow = convertedRow.clone(); - assertEquals(convertedRow, clonedRow); - assertEquals(row, convertedRow); - } - { - // Test repeated fields. 
- Schema subBirdSchema = AvroCoder.of(Bird.SubBird.class).getSchema(); - GenericRecord nestedRecord = new GenericData.Record(subBirdSchema); - nestedRecord.put("species", "other"); - GenericRecord record = new GenericData.Record(avroSchema); - record.put("number", 5L); - record.put("associates", Lists.newArrayList(nestedRecord)); - record.put("birthdayMoney", numericBytes); - record.put("lotteryWinnings", bigNumericBytes); - TableRow convertedRow = BigQueryAvroUtils.convertGenericRecordToTableRow(record); - TableRow row = + public void testConvertGenericRecordToTableRow() { + { + // bool + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().booleanType().noDefault())) + .set("value", false) + .build(); + TableRow expected = new TableRow().set("value", false); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // int + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().intType().noDefault())) + .set("value", 5) + .build(); + TableRow expected = new TableRow().set("value", "5"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // long + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().longType().noDefault())) + .set("value", 5L) + .build(); + TableRow expected = new TableRow().set("value", "5"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // float + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().floatType().noDefault())) + .set("value", 5.5f) + .build(); + TableRow expected = new TableRow().set("value", 5.5); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, 
row.clone()); + } + + { + // double + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().doubleType().noDefault())) + .set("value", 5.55) + .build(); + TableRow expected = new TableRow().set("value", 5.55); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // bytes + byte[] bytes = "chirp,chirp".getBytes(StandardCharsets.UTF_8); + ByteBuffer bb = ByteBuffer.wrap(bytes); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().bytesType().noDefault())) + .set("value", bb) + .build(); + TableRow expected = new TableRow().set("value", BaseEncoding.base64().encode(bytes)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // string + Schema schema = avroSchema(f -> f.type().stringType().noDefault()); + GenericRecord record = new GenericRecordBuilder(schema).set("value", "test").build(); + GenericRecord utf8Record = + new GenericRecordBuilder(schema).set("value", new Utf8("test")).build(); + TableRow expected = new TableRow().set("value", "test"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + TableRow utf8Row = BigQueryAvroUtils.convertGenericRecordToTableRow(utf8Record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + assertEquals(expected, utf8Row); + assertEquals(expected, utf8Row.clone()); + } + + { + // decimal + LogicalType lt = LogicalTypes.decimal(38, 9); + Schema decimalType = lt.addToSchema(SchemaBuilder.builder().bytesType()); + BigDecimal bd = new BigDecimal("123456789.123456789"); + ByteBuffer bytes = new Conversions.DecimalConversion().toBytes(bd, null, lt); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(decimalType).noDefault())) + .set("value", bytes) + .build(); + TableRow expected = new TableRow().set("value", 
bd.toString()); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // date + LogicalType lt = LogicalTypes.date(); + Schema dateType = lt.addToSchema(SchemaBuilder.builder().intType()); + LocalDate date = LocalDate.of(2000, 1, 1); + int days = (int) date.toEpochDay(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(dateType).noDefault())) + .set("value", days) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // time-millis + LogicalType lt = LogicalTypes.timeMillis(); + Schema timeType = lt.addToSchema(SchemaBuilder.builder().intType()); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + int millis = (int) (time.toNanoOfDay() / 1000000); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timeType).noDefault())) + .set("value", millis) + .build(); + TableRow expected = new TableRow().set("value", "01:02:03.123"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // time-micros + LogicalType lt = LogicalTypes.timeMicros(); + Schema timeType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + long micros = time.toNanoOfDay() / 1000; + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timeType).noDefault())) + .set("value", micros) + .build(); + TableRow expected = new TableRow().set("value", "01:02:03.123456"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // local-timestamp-millis + LogicalType lt = 
LogicalTypes.localTimestampMillis(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long millis = ts.toInstant(ZoneOffset.UTC).toEpochMilli(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", millis) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // local-timestamp-micros + LogicalType lt = LogicalTypes.localTimestampMicros(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long seconds = ts.toInstant(ZoneOffset.UTC).getEpochSecond(); + int nanos = ts.toInstant(ZoneOffset.UTC).getNano(); + long micros = seconds * 1000000 + (nanos / 1000); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", micros) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123456"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // timestamp-micros + LogicalType lt = LogicalTypes.timestampMillis(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long millis = ts.toInstant(ZoneOffset.UTC).toEpochMilli(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> 
f.type(timestampType).noDefault())) + .set("value", millis) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123 UTC"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // timestamp-millis + LogicalType lt = LogicalTypes.timestampMicros(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long seconds = ts.toInstant(ZoneOffset.UTC).getEpochSecond(); + int nanos = ts.toInstant(ZoneOffset.UTC).getNano(); + long micros = seconds * 1000000 + (nanos / 1000); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", micros) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123456 UTC"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // enum + Schema enumSchema = SchemaBuilder.enumeration("color").symbols("red", "green", "blue"); + GenericData.EnumSymbol symbol = new GenericData.EnumSymbol(enumSchema, "RED"); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(enumSchema).noDefault())) + .set("value", symbol) + .build(); + TableRow expected = new TableRow().set("value", "RED"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // fixed + UUID uuid = UUID.randomUUID(); + ByteBuffer bb = ByteBuffer.allocate(16); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + bb.rewind(); + byte[] bytes = bb.array(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> 
f.type().fixed("uuid").size(16).noDefault())) + .set("value", bb) + .build(); + TableRow expected = new TableRow().set("value", BaseEncoding.base64().encode(bytes)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // null + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().optional().booleanType())).build(); + TableRow expected = new TableRow(); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // array + GenericRecord record = + new GenericRecordBuilder( + avroSchema(f -> f.type().array().items().booleanType().noDefault())) + .set("value", Lists.newArrayList(true, false)) + .build(); + TableRow expected = new TableRow().set("value", Lists.newArrayList(true, false)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // map + Map map = new HashMap<>(); + map.put("left", 1); + map.put("right", -1); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().map().values().intType().noDefault())) + .set("value", map) + .build(); + TableRow expected = new TableRow() - .set("associates", Lists.newArrayList(new TableRow().set("species", "other"))) - .set("number", "5") - .set("birthdayMoney", numeric.toString()) - .set("lotteryWinnings", bigNumeric.toString()); - assertEquals(row, convertedRow); - TableRow clonedRow = convertedRow.clone(); - assertEquals(convertedRow, clonedRow); + .set( + "value", + Lists.newArrayList( + new TableRow().set("key", "left").set("value", "1"), + new TableRow().set("key", "right").set("value", "-1"))); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // record + Schema 
subSchema = + SchemaBuilder.builder() + .record("record") + .fields() + .name("int") + .type() + .intType() + .noDefault() + .name("float") + .type() + .floatType() + .noDefault() + .endRecord(); + GenericRecord subRecord = + new GenericRecordBuilder(subSchema).set("int", 5).set("float", 5.5f).build(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(subSchema).noDefault())) + .set("value", subRecord) + .build(); + TableRow expected = + new TableRow().set("value", new TableRow().set("int", "5").set("float", 5.5)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); } } @Test public void testConvertBigQuerySchemaToAvroSchema() { - TableSchema tableSchema = new TableSchema(); - tableSchema.setFields(fields); - Schema avroSchema = BigQueryAvroUtils.toGenericAvroSchema(tableSchema); + { + // REQUIRED + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().booleanType().noDefault()); - assertThat(avroSchema.getField("number").schema(), equalTo(Schema.create(Type.LONG))); - assertThat( - avroSchema.getField("species").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); - assertThat( - avroSchema.getField("quality").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.DOUBLE)))); - assertThat( - avroSchema.getField("quantity").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.LONG)))); - assertThat( - avroSchema.getField("birthday").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - LogicalTypes.timestampMicros().addToSchema(Schema.create(Type.LONG))))); - assertThat( - avroSchema.getField("birthdayMoney").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - LogicalTypes.decimal(38, 9).addToSchema(Schema.create(Type.BYTES))))); - 
assertThat( - avroSchema.getField("lotteryWinnings").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - LogicalTypes.decimal(77, 38).addToSchema(Schema.create(Type.BYTES))))); - assertThat( - avroSchema.getField("flighted").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BOOLEAN)))); - assertThat( - avroSchema.getField("sound").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BYTES)))); - Schema dateSchema = Schema.create(Type.INT); - LogicalTypes.date().addToSchema(dateSchema); - assertThat( - avroSchema.getField("anniversaryDate").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), dateSchema))); - Schema dateTimeSchema = Schema.create(Type.STRING); - BigQueryAvroUtils.DATETIME_LOGICAL_TYPE.addToSchema(dateTimeSchema); - assertThat( - avroSchema.getField("anniversaryDatetime").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), dateTimeSchema))); - Schema timeSchema = Schema.create(Type.LONG); - LogicalTypes.timeMicros().addToSchema(timeSchema); - assertThat( - avroSchema.getField("anniversaryTime").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), timeSchema))); - Schema geoSchema = Schema.create(Type.STRING); - geoSchema.addProp("sqlType", "GEOGRAPHY"); - assertThat( - avroSchema.getField("geoPositions").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), geoSchema))); - assertThat( - avroSchema.getField("scion").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - Schema.createRecord( - "scion", - "Translated Avro Schema for scion", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Field( - "species", - Schema.createUnion( - Schema.create(Type.NULL), Schema.create(Type.STRING)), - null, - (Object) null)))))); - assertThat( - avroSchema.getField("associates").schema(), - equalTo( - Schema.createArray( - Schema.createRecord( - "associates", - "Translated Avro Schema for 
associates", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Field( - "species", - Schema.createUnion( - Schema.create(Type.NULL), Schema.create(Type.STRING)), - null, - (Object) null)))))); - } + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } - @Test - public void testConvertBigQuerySchemaToAvroSchemaWithoutLogicalTypes() { - TableSchema tableSchema = new TableSchema(); - tableSchema.setFields(fields); - Schema avroSchema = BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false); + { + // NULLABLE + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN").setMode("NULLABLE")); + Schema expected = + avroSchema(f -> f.type().unionOf().nullType().and().booleanType().endUnion().noDefault()); - assertThat(avroSchema.getField("number").schema(), equalTo(Schema.create(Schema.Type.LONG))); - assertThat( - avroSchema.getField("species").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)))); - assertThat( - avroSchema.getField("quality").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.DOUBLE)))); - assertThat( - avroSchema.getField("quantity").schema(), - equalTo( - Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.LONG)))); - assertThat( - avroSchema.getField("birthday").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG))))); - assertThat( - avroSchema.getField("birthdayMoney").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - LogicalTypes.decimal(38, 9).addToSchema(Schema.create(Schema.Type.BYTES))))); - assertThat( - avroSchema.getField("lotteryWinnings").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - LogicalTypes.decimal(77, 38).addToSchema(Schema.create(Schema.Type.BYTES))))); - assertThat( - 
avroSchema.getField("flighted").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.BOOLEAN)))); - assertThat( - avroSchema.getField("sound").schema(), - equalTo( - Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.BYTES)))); - Schema dateSchema = Schema.create(Schema.Type.STRING); - dateSchema.addProp("sqlType", "DATE"); - assertThat( - avroSchema.getField("anniversaryDate").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), dateSchema))); - Schema dateTimeSchema = Schema.create(Schema.Type.STRING); - dateTimeSchema.addProp("sqlType", "DATETIME"); - assertThat( - avroSchema.getField("anniversaryDatetime").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), dateTimeSchema))); - Schema timeSchema = Schema.create(Schema.Type.STRING); - timeSchema.addProp("sqlType", "TIME"); - assertThat( - avroSchema.getField("anniversaryTime").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), timeSchema))); - Schema geoSchema = Schema.create(Type.STRING); - geoSchema.addProp("sqlType", "GEOGRAPHY"); - assertThat( - avroSchema.getField("geoPositions").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), geoSchema))); - assertThat( - avroSchema.getField("scion").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - Schema.createRecord( - "scion", - "Translated Avro Schema for scion", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Schema.Field( - "species", - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)), - null, - (Object) null)))))); - assertThat( - avroSchema.getField("associates").schema(), - equalTo( - Schema.createArray( - Schema.createRecord( - "associates", - "Translated Avro Schema for associates", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Schema.Field( - "species", - Schema.createUnion( 
- Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)), - null, - (Object) null)))))); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // default mode -> NULLABLE + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN")); + Schema expected = + avroSchema(f -> f.type().unionOf().nullType().and().booleanType().endUnion().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // REPEATED + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN").setMode("REPEATED")); + Schema expected = avroSchema(f -> f.type().array().items().booleanType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // INTEGER + TableSchema tableSchema = tableSchema(f -> f.setType("INTEGER").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().longType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // FLOAT + TableSchema tableSchema = tableSchema(f -> f.setType("FLOAT").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().doubleType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BYTES + TableSchema tableSchema = tableSchema(f -> f.setType("BYTES").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().bytesType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // STRING + TableSchema tableSchema = tableSchema(f -> f.setType("STRING").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().stringType().noDefault()); + + 
assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // NUMERIC + TableSchema tableSchema = tableSchema(f -> f.setType("NUMERIC").setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(38, 9).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // NUMERIC with precision + TableSchema tableSchema = + tableSchema(f -> f.setType("NUMERIC").setPrecision(29L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(29, 0).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // NUMERIC with precision and scale + TableSchema tableSchema = + tableSchema(f -> f.setType("NUMERIC").setPrecision(10L).setScale(9L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(10, 9).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BIGNUMERIC + TableSchema tableSchema = tableSchema(f -> f.setType("BIGNUMERIC").setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(77, 38).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, 
BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BIGNUMERIC with precision + TableSchema tableSchema = + tableSchema(f -> f.setType("BIGNUMERIC").setPrecision(38L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(38, 0).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BIGNUMERIC with precision and scale + TableSchema tableSchema = + tableSchema( + f -> f.setType("BIGNUMERIC").setPrecision(39L).setScale(38L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(39, 38).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // DATE + TableSchema tableSchema = tableSchema(f -> f.setType("DATE").setMode("REQUIRED")); + Schema dateType = LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType()); + Schema expected = avroSchema(f -> f.type(dateType).noDefault()); + Schema expectedExport = + avroSchema(f -> f.type().stringBuilder().prop("sqlType", "DATE").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expectedExport, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // TIME + TableSchema tableSchema = tableSchema(f -> f.setType("TIME").setMode("REQUIRED")); + Schema timeType = LogicalTypes.timeMicros().addToSchema(SchemaBuilder.builder().longType()); + Schema expected = avroSchema(f -> f.type(timeType).noDefault()); + Schema expectedExport = + avroSchema(f -> f.type().stringBuilder().prop("sqlType", 
"TIME").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expectedExport, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // DATETIME + TableSchema tableSchema = tableSchema(f -> f.setType("DATETIME").setMode("REQUIRED")); + Schema timeType = + BigQueryAvroUtils.DATETIME_LOGICAL_TYPE.addToSchema(SchemaBuilder.builder().stringType()); + Schema expected = avroSchema(f -> f.type(timeType).noDefault()); + Schema expectedExport = + avroSchema( + f -> f.type().stringBuilder().prop("sqlType", "DATETIME").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expectedExport, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // TIMESTAMP + TableSchema tableSchema = tableSchema(f -> f.setType("TIMESTAMP").setMode("REQUIRED")); + Schema timestampType = + LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder().longType()); + Schema expected = avroSchema(f -> f.type(timestampType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // GEOGRAPHY + TableSchema tableSchema = tableSchema(f -> f.setType("GEOGRAPHY").setMode("REQUIRED")); + Schema expected = + avroSchema( + f -> f.type().stringBuilder().prop("sqlType", "GEOGRAPHY").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // JSON + TableSchema tableSchema = tableSchema(f -> f.setType("JSON").setMode("REQUIRED")); + Schema expected = + avroSchema(f -> f.type().stringBuilder().prop("sqlType", "JSON").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, 
BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // STRUCT/RECORD + TableFieldSchema subInteger = + new TableFieldSchema().setName("int").setType("INTEGER").setMode("NULLABLE"); + TableFieldSchema subFloat = + new TableFieldSchema().setName("float").setType("FLOAT").setMode("REQUIRED"); + TableSchema structTableSchema = + tableSchema( + f -> + f.setType("STRUCT") + .setMode("REQUIRED") + .setFields(Lists.newArrayList(subInteger, subFloat))); + TableSchema recordTableSchema = + tableSchema( + f -> + f.setType("RECORD") + .setMode("REQUIRED") + .setFields(Lists.newArrayList(subInteger, subFloat))); + + Schema expected = + avroSchema( + f -> + f.type() + .record("value") + .fields() + .name("int") + .type() + .unionOf() + .nullType() + .and() + .longType() + .endUnion() + .noDefault() + .name("float") + .type() + .doubleType() + .noDefault() + .endRecord() + .noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(structTableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(structTableSchema, false)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(recordTableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(recordTableSchema, false)); + } } @Test public void testFormatTimestamp() { - assertThat( - BigQueryAvroUtils.formatTimestamp(1452062291123456L), - equalTo("2016-01-06 06:38:11.123456 UTC")); + long micros = 1452062291123456L; + String expected = "2016-01-06 06:38:11.123456"; + assertThat(BigQueryAvroUtils.formatDatetime(micros), equalTo(expected)); + assertThat(BigQueryAvroUtils.formatTimestamp(micros), equalTo(expected + " UTC")); } @Test - public void testFormatTimestampLeadingZeroesOnMicros() { - assertThat( - BigQueryAvroUtils.formatTimestamp(1452062291000456L), - equalTo("2016-01-06 06:38:11.000456 UTC")); + public void testFormatTimestampMillis() { + long millis = 1452062291123L; + long micros = millis * 1000L; + String expected = "2016-01-06 
06:38:11.123"; + assertThat(BigQueryAvroUtils.formatDatetime(micros), equalTo(expected)); + assertThat(BigQueryAvroUtils.formatTimestamp(micros), equalTo(expected + " UTC")); } @Test - public void testFormatTimestampTrailingZeroesOnMicros() { - assertThat( - BigQueryAvroUtils.formatTimestamp(1452062291123000L), - equalTo("2016-01-06 06:38:11.123000 UTC")); + public void testFormatTimestampSeconds() { + long seconds = 1452062291L; + long micros = seconds * 1000L * 1000L; + String expected = "2016-01-06 06:38:11"; + assertThat(BigQueryAvroUtils.formatDatetime(micros), equalTo(expected)); + assertThat(BigQueryAvroUtils.formatTimestamp(micros), equalTo(expected + " UTC")); } @Test public void testFormatTimestampNegative() { - assertThat(BigQueryAvroUtils.formatTimestamp(-1L), equalTo("1969-12-31 23:59:59.999999 UTC")); - assertThat( - BigQueryAvroUtils.formatTimestamp(-100_000L), equalTo("1969-12-31 23:59:59.900000 UTC")); - assertThat(BigQueryAvroUtils.formatTimestamp(-1_000_000L), equalTo("1969-12-31 23:59:59 UTC")); + assertThat(BigQueryAvroUtils.formatDatetime(-1L), equalTo("1969-12-31 23:59:59.999999")); + assertThat(BigQueryAvroUtils.formatDatetime(-100_000L), equalTo("1969-12-31 23:59:59.900")); + assertThat(BigQueryAvroUtils.formatDatetime(-1_000_000L), equalTo("1969-12-31 23:59:59")); // No leap seconds before 1972. 477 leap years from 1 through 1969. assertThat( - BigQueryAvroUtils.formatTimestamp(-(1969L * 365 + 477) * 86400 * 1_000_000), - equalTo("0001-01-01 00:00:00 UTC")); + BigQueryAvroUtils.formatDatetime(-(1969L * 365 + 477) * 86400 * 1_000_000), + equalTo("0001-01-01 00:00:00")); } @Test @@ -501,48 +816,4 @@ public void testSchemaCollisionsInAvroConversion() { String output = BigQueryAvroUtils.toGenericAvroSchema(schema, false).toString(); assertThat(output.length(), greaterThan(0)); } - - /** Pojo class used as the record type in tests. */ - @SuppressWarnings("unused") // Used by Avro reflection. 
- static class Bird { - long number; - @Nullable String species; - @Nullable Double quality; - @Nullable Long quantity; - - @AvroSchema(value = "[\"null\", {\"type\": \"long\", \"logicalType\": \"timestamp-micros\"}]") - Instant birthday; - - @AvroSchema( - value = - "[\"null\", {\"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 38, \"scale\": 9}]") - BigDecimal birthdayMoney; - - @AvroSchema( - value = - "[\"null\", {\"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 77, \"scale\": 38}]") - BigDecimal lotteryWinnings; - - @AvroSchema(value = "[\"null\", {\"type\": \"string\", \"sqlType\": \"GEOGRAPHY\"}]") - String geoPositions; - - @Nullable Boolean flighted; - @Nullable ByteBuffer sound; - @Nullable Utf8 anniversaryDate; - @Nullable String anniversaryDatetime; - @Nullable Utf8 anniversaryTime; - @Nullable SubBird scion; - SubBird[] associates; - - static class SubBird { - @Nullable String species; - - public SubBird() {} - } - - public Bird() { - associates = new SubBird[1]; - associates[0] = new SubBird(); - } - } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java index e15258e6ab40..5b7b5d473190 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java @@ -96,6 +96,7 @@ public class BigQueryIOTranslationTest { WRITE_TRANSFORM_SCHEMA_MAPPING.put("getWriteDisposition", "write_disposition"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getSchemaUpdateOptions", "schema_update_options"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getTableDescription", "table_description"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getBigLakeConfiguration", "biglake_configuration"); 
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getValidate", "validate"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getBigQueryServices", "bigquery_services"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getMaxFilesPerBundle", "max_files_per_bundle"); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index d96e22f84907..57c71c023fcb 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.toJsonString; import static org.apache.beam.sdk.io.gcp.bigquery.WriteTables.ResultCoder.INSTANCE; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider.BigQueryFileLoadsSchemaTransform; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -32,6 +33,7 @@ import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -117,11 +119,13 @@ import org.apache.beam.sdk.io.gcp.bigquery.WritePartition.ResultCoder; import org.apache.beam.sdk.io.gcp.bigquery.WriteRename.TempTableCleanupFn; import org.apache.beam.sdk.io.gcp.bigquery.WriteTables.Result; +import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider; 
import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.Schema; @@ -818,6 +822,25 @@ public void testStreamingFileLoadsWithAutoSharding() throws Exception { assertEquals(2 * numTables, fakeDatasetService.getInsertCount()); } + @Test + public void testFileLoadSchemaTransformUsesAvroFormat() { + // ensure we are writing with the more performant avro format + assumeTrue(!useStreaming); + assumeTrue(!useStorageApi); + BigQueryFileLoadsSchemaTransformProvider provider = + new BigQueryFileLoadsSchemaTransformProvider(); + Row configuration = + Row.withSchema(provider.configurationSchema()) + .withFieldValue("table", "some-table") + .build(); + BigQueryFileLoadsSchemaTransform schemaTransform = + (BigQueryFileLoadsSchemaTransform) provider.from(configuration); + BigQueryIO.Write write = + schemaTransform.toWrite(Schema.of(), PipelineOptionsFactory.create()); + assertNull(write.getFormatFunction()); + assertNotNull(write.getAvroRowWriterFactory()); + } + @Test public void testBatchFileLoads() throws Exception { assumeTrue(!useStreaming); @@ -2257,6 +2280,40 @@ public void testUpdateTableSchemaNoUnknownValues() throws Exception { p.run(); } + @Test + public void testBigLakeConfigurationFailsForNonStorageApiWrites() { + assumeTrue(!useStorageApi); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "bigLakeConfiguration is only supported when using STORAGE_WRITE_API or STORAGE_API_AT_LEAST_ONCE"); + + p.apply(Create.empty(TableRowJsonCoder.of())) + .apply( + BigQueryIO.writeTableRows() + .to("project-id:dataset-id.table") + 
.withBigLakeConfiguration( + ImmutableMap.of( + "connectionId", "some-connection", + "storageUri", "gs://bucket")) + .withTestServices(fakeBqServices)); + p.run(); + } + + @Test + public void testBigLakeConfigurationFailsForMissingProperties() { + assumeTrue(useStorageApi); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("bigLakeConfiguration must contain keys 'connectionId' and 'storageUri'"); + + p.apply(Create.empty(TableRowJsonCoder.of())) + .apply( + BigQueryIO.writeTableRows() + .to("project-id:dataset-id.table") + .withBigLakeConfiguration(ImmutableMap.of("connectionId", "some-connection")) + .withTestServices(fakeBqServices)); + p.run(); + } + @SuppressWarnings({"unused"}) static class UpdateTableSchemaDoFn extends DoFn, TableRow> { @TimerId("updateTimer") diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkCreateIfNeededIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkCreateIfNeededIT.java index 18c832f0c54b..858921e19ced 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkCreateIfNeededIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkCreateIfNeededIT.java @@ -17,23 +17,32 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.CONNECTION_ID; +import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.STORAGE_URI; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; import com.google.api.services.bigquery.model.TableFieldSchema; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.storage.model.Objects; import java.io.IOException; import java.util.List; +import java.util.Map; import 
java.util.stream.Collectors; import java.util.stream.LongStream; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.extensions.gcp.options.GcsOptions; +import org.apache.beam.sdk.extensions.gcp.util.GcsUtil; import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; import org.joda.time.Duration; import org.junit.AfterClass; @@ -57,11 +66,16 @@ public static Iterable data() { private static final Logger LOG = LoggerFactory.getLogger(StorageApiSinkCreateIfNeededIT.class); - private static final BigqueryClient BQ_CLIENT = new BigqueryClient("StorageApiSinkFailedRowsIT"); + private static final BigqueryClient BQ_CLIENT = + new BigqueryClient("StorageApiSinkCreateIfNeededIT"); private static final String PROJECT = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); private static final String BIG_QUERY_DATASET_ID = - "storage_api_sink_failed_rows" + System.nanoTime(); + "storage_api_sink_create_tables_" + System.nanoTime(); + private static final String TEST_CONNECTION_ID = + "projects/apache-beam-testing/locations/us/connections/apache-beam-testing-storageapi-biglake-nodelete"; + private static final String TEST_STORAGE_URI = + "gs://apache-beam-testing-bq-biglake/" + StorageApiSinkCreateIfNeededIT.class.getSimpleName(); private static final List FIELDS = ImmutableList.builder() .add(new TableFieldSchema().setType("STRING").setName("str")) @@ -96,19 +110,55 @@ public void 
testCreateManyTables() throws IOException, InterruptedException { String table = "table" + System.nanoTime(); String tableSpecBase = PROJECT + "." + BIG_QUERY_DATASET_ID + "." + table; - runPipeline(getMethod(), tableSpecBase, inputs); - assertTablesCreated(tableSpecBase, 100); + runPipeline(getMethod(), tableSpecBase, inputs, null); + assertTablesCreated(tableSpecBase, 100, true); } - private void assertTablesCreated(String tableSpecPrefix, int expectedRows) + @Test + public void testCreateBigLakeTables() throws IOException, InterruptedException { + int numTables = 5; + List inputs = + LongStream.range(0, numTables) + .mapToObj(l -> new TableRow().set("str", "foo").set("tablenum", l)) + .collect(Collectors.toList()); + + String table = "iceberg_table_" + System.nanoTime() + "_"; + String tableSpecBase = PROJECT + "." + BIG_QUERY_DATASET_ID + "." + table; + Map bigLakeConfiguration = + ImmutableMap.of( + CONNECTION_ID, TEST_CONNECTION_ID, + STORAGE_URI, TEST_STORAGE_URI); + runPipeline(getMethod(), tableSpecBase, inputs, bigLakeConfiguration); + assertTablesCreated(tableSpecBase, numTables, false); + assertIcebergTablesCreated(table, numTables); + } + + private void assertIcebergTablesCreated(String tablePrefix, int expectedRows) throws IOException, InterruptedException { + GcsUtil gcsUtil = TestPipeline.testingPipelineOptions().as(GcsOptions.class).getGcsUtil(); + + Objects objects = + gcsUtil.listObjects( + "apache-beam-testing-bq-biglake", + String.format( + "%s/%s/%s/%s", + getClass().getSimpleName(), PROJECT, BIG_QUERY_DATASET_ID, tablePrefix), + null); + + assertEquals(expectedRows, objects.getItems().size()); + } + + private void assertTablesCreated(String tableSpecPrefix, int expectedRows, boolean useWildCard) + throws IOException, InterruptedException { + String query = String.format("SELECT COUNT(*) FROM `%s`", tableSpecPrefix + "*"); + if (!useWildCard) { + query = String.format("SELECT (SELECT COUNT(*) FROM `%s`)", tableSpecPrefix + 0); + for (int i = 
1; i < expectedRows; i++) { + query += String.format(" + (SELECT COUNT(*) FROM `%s`)", tableSpecPrefix + i); + } + } TableRow queryResponse = - Iterables.getOnlyElement( - BQ_CLIENT.queryUnflattened( - String.format("SELECT COUNT(*) FROM `%s`", tableSpecPrefix + "*"), - PROJECT, - true, - true)); + Iterables.getOnlyElement(BQ_CLIENT.queryUnflattened(query, PROJECT, true, true)); int numRowsWritten = Integer.parseInt((String) queryResponse.get("f0_")); if (useAtLeastOnce) { assertThat(numRowsWritten, Matchers.greaterThanOrEqualTo(expectedRows)); @@ -118,7 +168,10 @@ private void assertTablesCreated(String tableSpecPrefix, int expectedRows) } private static void runPipeline( - BigQueryIO.Write.Method method, String tableSpecBase, Iterable tableRows) { + BigQueryIO.Write.Method method, + String tableSpecBase, + Iterable tableRows, + @Nullable Map bigLakeConfiguration) { Pipeline p = Pipeline.create(); BigQueryIO.Write write = @@ -131,6 +184,9 @@ private static void runPipeline( write = write.withNumStorageWriteApiStreams(1); write = write.withTriggeringFrequency(Duration.standardSeconds(1)); } + if (bigLakeConfiguration != null) { + write = write.withBigLakeConfiguration(bigLakeConfiguration); + } PCollection input = p.apply("Create test cases", Create.of(tableRows)); input = input.setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED); input.apply("Write using Storage Write API", write); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordActionTest.java index 03d390ea0d5d..5815bf0c6fdd 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordActionTest.java +++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/ChildPartitionsRecordActionTest.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChildPartition; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChildPartitionsRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.io.gcp.spanner.changestreams.util.TestTransactionAnswer; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; @@ -55,6 +56,7 @@ public class ChildPartitionsRecordActionTest { private ChangeStreamMetrics metrics; private ChildPartitionsRecordAction action; private RestrictionTracker tracker; + private RestrictionInterrupter interrupter; private ManualWatermarkEstimator watermarkEstimator; @Before @@ -64,6 +66,7 @@ public void setUp() { metrics = mock(ChangeStreamMetrics.class); action = new ChildPartitionsRecordAction(dao, metrics); tracker = mock(RestrictionTracker.class); + interrupter = mock(RestrictionInterrupter.class); watermarkEstimator = mock(ManualWatermarkEstimator.class); when(dao.runInTransaction(any(), anyObject())) @@ -93,7 +96,7 @@ public void testRestrictionClaimedAndIsSplitCase() { when(transaction.getPartition("childPartition2")).thenReturn(null); final Optional maybeContinuation = - action.run(partition, record, tracker, watermarkEstimator); + action.run(partition, record, tracker, interrupter, watermarkEstimator); assertEquals(Optional.empty(), maybeContinuation); verify(watermarkEstimator).setWatermark(new Instant(startTimestamp.toSqlTimestamp().getTime())); @@ -144,7 +147,7 @@ public void testRestrictionClaimedAnsIsSplitCaseAndChildExists() { when(transaction.getPartition("childPartition2")).thenReturn(mock(Struct.class)); final 
Optional maybeContinuation = - action.run(partition, record, tracker, watermarkEstimator); + action.run(partition, record, tracker, interrupter, watermarkEstimator); assertEquals(Optional.empty(), maybeContinuation); verify(watermarkEstimator).setWatermark(new Instant(startTimestamp.toSqlTimestamp().getTime())); @@ -173,7 +176,7 @@ public void testRestrictionClaimedAndIsMergeCaseAndChildNotExists() { when(transaction.getPartition(childPartitionToken)).thenReturn(null); final Optional maybeContinuation = - action.run(partition, record, tracker, watermarkEstimator); + action.run(partition, record, tracker, interrupter, watermarkEstimator); assertEquals(Optional.empty(), maybeContinuation); verify(watermarkEstimator).setWatermark(new Instant(startTimestamp.toSqlTimestamp().getTime())); @@ -213,7 +216,7 @@ public void testRestrictionClaimedAndIsMergeCaseAndChildExists() { when(transaction.getPartition(childPartitionToken)).thenReturn(mock(Struct.class)); final Optional maybeContinuation = - action.run(partition, record, tracker, watermarkEstimator); + action.run(partition, record, tracker, interrupter, watermarkEstimator); assertEquals(Optional.empty(), maybeContinuation); verify(watermarkEstimator).setWatermark(new Instant(startTimestamp.toSqlTimestamp().getTime())); @@ -237,10 +240,35 @@ public void testRestrictionNotClaimed() { when(tracker.tryClaim(startTimestamp)).thenReturn(false); final Optional maybeContinuation = - action.run(partition, record, tracker, watermarkEstimator); + action.run(partition, record, tracker, interrupter, watermarkEstimator); assertEquals(Optional.of(ProcessContinuation.stop()), maybeContinuation); verify(watermarkEstimator, never()).setWatermark(any()); verify(dao, never()).insert(any()); } + + @Test + public void testSoftDeadlineReached() { + final String partitionToken = "partitionToken"; + final Timestamp startTimestamp = Timestamp.ofTimeMicroseconds(10L); + final PartitionMetadata partition = mock(PartitionMetadata.class); + final 
ChildPartitionsRecord record = + new ChildPartitionsRecord( + startTimestamp, + "recordSequence", + Arrays.asList( + new ChildPartition("childPartition1", partitionToken), + new ChildPartition("childPartition2", partitionToken)), + null); + when(partition.getPartitionToken()).thenReturn(partitionToken); + when(interrupter.tryInterrupt(startTimestamp)).thenReturn(true); + when(tracker.tryClaim(startTimestamp)).thenReturn(true); + + final Optional maybeContinuation = + action.run(partition, record, tracker, interrupter, watermarkEstimator); + + assertEquals(Optional.of(ProcessContinuation.resume()), maybeContinuation); + verify(watermarkEstimator, never()).setWatermark(any()); + verify(dao, never()).insert(any()); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordActionTest.java index ac8d48725299..6569f810812c 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordActionTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/DataChangeRecordActionTest.java @@ -30,6 +30,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.estimator.BytesThroughputEstimator; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; @@ -44,6 +45,7 @@ public class DataChangeRecordActionTest { private 
DataChangeRecordAction action; private PartitionMetadata partition; private RestrictionTracker tracker; + private RestrictionInterrupter interrupter; private OutputReceiver outputReceiver; private ManualWatermarkEstimator watermarkEstimator; private BytesThroughputEstimator throughputEstimator; @@ -54,6 +56,7 @@ public void setUp() { action = new DataChangeRecordAction(throughputEstimator); partition = mock(PartitionMetadata.class); tracker = mock(RestrictionTracker.class); + interrupter = mock(RestrictionInterrupter.class); outputReceiver = mock(OutputReceiver.class); watermarkEstimator = mock(ManualWatermarkEstimator.class); } @@ -69,7 +72,7 @@ public void testRestrictionClaimed() { when(partition.getPartitionToken()).thenReturn(partitionToken); final Optional maybeContinuation = - action.run(partition, record, tracker, outputReceiver, watermarkEstimator); + action.run(partition, record, tracker, interrupter, outputReceiver, watermarkEstimator); assertEquals(Optional.empty(), maybeContinuation); verify(outputReceiver).outputWithTimestamp(record, instant); @@ -87,11 +90,30 @@ public void testRestrictionNotClaimed() { when(partition.getPartitionToken()).thenReturn(partitionToken); final Optional maybeContinuation = - action.run(partition, record, tracker, outputReceiver, watermarkEstimator); + action.run(partition, record, tracker, interrupter, outputReceiver, watermarkEstimator); assertEquals(Optional.of(ProcessContinuation.stop()), maybeContinuation); verify(outputReceiver, never()).outputWithTimestamp(any(), any()); verify(watermarkEstimator, never()).setWatermark(any()); verify(throughputEstimator, never()).update(any(), any()); } + + @Test + public void testSoftDeadlineReached() { + final String partitionToken = "partitionToken"; + final Timestamp timestamp = Timestamp.ofTimeMicroseconds(10L); + final DataChangeRecord record = mock(DataChangeRecord.class); + when(record.getCommitTimestamp()).thenReturn(timestamp); + 
when(interrupter.tryInterrupt(timestamp)).thenReturn(true); + when(tracker.tryClaim(timestamp)).thenReturn(true); + when(partition.getPartitionToken()).thenReturn(partitionToken); + + final Optional maybeContinuation = + action.run(partition, record, tracker, interrupter, outputReceiver, watermarkEstimator); + + assertEquals(Optional.of(ProcessContinuation.resume()), maybeContinuation); + verify(outputReceiver, never()).outputWithTimestamp(any(), any()); + verify(watermarkEstimator, never()).setWatermark(any()); + verify(throughputEstimator, never()).update(any(), any()); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordActionTest.java index 77333bbbc96e..56d1825c8a18 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordActionTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/HeartbeatRecordActionTest.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamMetrics; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.HeartbeatRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; @@ -42,6 +43,7 @@ public class HeartbeatRecordActionTest { private HeartbeatRecordAction action; private PartitionMetadata partition; private RestrictionTracker tracker; + private RestrictionInterrupter interrupter; private 
ManualWatermarkEstimator watermarkEstimator; @Before @@ -50,6 +52,7 @@ public void setUp() { action = new HeartbeatRecordAction(metrics); partition = mock(PartitionMetadata.class); tracker = mock(RestrictionTracker.class); + interrupter = mock(RestrictionInterrupter.class); watermarkEstimator = mock(ManualWatermarkEstimator.class); } @@ -62,7 +65,12 @@ public void testRestrictionClaimed() { when(partition.getPartitionToken()).thenReturn(partitionToken); final Optional maybeContinuation = - action.run(partition, new HeartbeatRecord(timestamp, null), tracker, watermarkEstimator); + action.run( + partition, + new HeartbeatRecord(timestamp, null), + tracker, + interrupter, + watermarkEstimator); assertEquals(Optional.empty(), maybeContinuation); verify(watermarkEstimator).setWatermark(new Instant(timestamp.toSqlTimestamp().getTime())); @@ -77,9 +85,35 @@ public void testRestrictionNotClaimed() { when(partition.getPartitionToken()).thenReturn(partitionToken); final Optional maybeContinuation = - action.run(partition, new HeartbeatRecord(timestamp, null), tracker, watermarkEstimator); + action.run( + partition, + new HeartbeatRecord(timestamp, null), + tracker, + interrupter, + watermarkEstimator); assertEquals(Optional.of(ProcessContinuation.stop()), maybeContinuation); verify(watermarkEstimator, never()).setWatermark(any()); } + + @Test + public void testSoftDeadlineReached() { + final String partitionToken = "partitionToken"; + final Timestamp timestamp = Timestamp.ofTimeMicroseconds(10L); + + when(interrupter.tryInterrupt(timestamp)).thenReturn(true); + when(tracker.tryClaim(timestamp)).thenReturn(true); + when(partition.getPartitionToken()).thenReturn(partitionToken); + + final Optional maybeContinuation = + action.run( + partition, + new HeartbeatRecord(timestamp, null), + tracker, + interrupter, + watermarkEstimator); + + assertEquals(Optional.of(ProcessContinuation.resume()), maybeContinuation); + verify(watermarkEstimator, never()).setWatermark(any()); + } } 
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamActionTest.java index bf7b0adfd475..c73a62a812bd 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamActionTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/action/QueryChangeStreamActionTest.java @@ -20,6 +20,7 @@ import static org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata.State.SCHEDULED; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -40,6 +41,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.HeartbeatRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.RestrictionInterrupter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction.TimestampRange; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; @@ -144,10 +146,20 @@ public void testQueryChangeStreamWithDataChangeRecord() { when(changeStreamRecordMapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)) .thenReturn(Arrays.asList(record1, record2)); when(dataChangeRecordAction.run( - partition, record1, restrictionTracker, outputReceiver, watermarkEstimator)) + eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + 
eq(outputReceiver), + eq(watermarkEstimator))) .thenReturn(Optional.empty()); when(dataChangeRecordAction.run( - partition, record2, restrictionTracker, outputReceiver, watermarkEstimator)) + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(outputReceiver), + eq(watermarkEstimator))) .thenReturn(Optional.of(ProcessContinuation.stop())); when(watermarkEstimator.currentWatermark()).thenReturn(WATERMARK); @@ -157,13 +169,25 @@ public void testQueryChangeStreamWithDataChangeRecord() { assertEquals(ProcessContinuation.stop(), result); verify(dataChangeRecordAction) - .run(partition, record1, restrictionTracker, outputReceiver, watermarkEstimator); + .run( + eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(outputReceiver), + eq(watermarkEstimator)); verify(dataChangeRecordAction) - .run(partition, record2, restrictionTracker, outputReceiver, watermarkEstimator); + .run( + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(outputReceiver), + eq(watermarkEstimator)); verify(partitionMetadataDao).updateWatermark(PARTITION_TOKEN, WATERMARK_TIMESTAMP); - verify(heartbeatRecordAction, never()).run(any(), any(), any(), any()); - verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any()); + verify(heartbeatRecordAction, never()).run(any(), any(), any(), any(), any()); + verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any(), any()); verify(restrictionTracker, never()).tryClaim(any()); } @@ -188,9 +212,19 @@ public void testQueryChangeStreamWithHeartbeatRecord() { when(resultSet.getMetadata()).thenReturn(resultSetMetadata); when(changeStreamRecordMapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)) .thenReturn(Arrays.asList(record1, record2)); - when(heartbeatRecordAction.run(partition, record1, restrictionTracker, watermarkEstimator)) + when(heartbeatRecordAction.run( + 
eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator))) .thenReturn(Optional.empty()); - when(heartbeatRecordAction.run(partition, record2, restrictionTracker, watermarkEstimator)) + when(heartbeatRecordAction.run( + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator))) .thenReturn(Optional.of(ProcessContinuation.stop())); when(watermarkEstimator.currentWatermark()).thenReturn(WATERMARK); @@ -199,12 +233,24 @@ public void testQueryChangeStreamWithHeartbeatRecord() { partition, restrictionTracker, outputReceiver, watermarkEstimator, bundleFinalizer); assertEquals(ProcessContinuation.stop(), result); - verify(heartbeatRecordAction).run(partition, record1, restrictionTracker, watermarkEstimator); - verify(heartbeatRecordAction).run(partition, record2, restrictionTracker, watermarkEstimator); + verify(heartbeatRecordAction) + .run( + eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator)); + verify(heartbeatRecordAction) + .run( + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator)); verify(partitionMetadataDao).updateWatermark(PARTITION_TOKEN, WATERMARK_TIMESTAMP); - verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any()); - verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any()); + verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any(), any()); + verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any(), any()); verify(restrictionTracker, never()).tryClaim(any()); } @@ -230,10 +276,18 @@ public void testQueryChangeStreamWithChildPartitionsRecord() { when(changeStreamRecordMapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)) .thenReturn(Arrays.asList(record1, record2)); when(childPartitionsRecordAction.run( - 
partition, record1, restrictionTracker, watermarkEstimator)) + eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator))) .thenReturn(Optional.empty()); when(childPartitionsRecordAction.run( - partition, record2, restrictionTracker, watermarkEstimator)) + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator))) .thenReturn(Optional.of(ProcessContinuation.stop())); when(watermarkEstimator.currentWatermark()).thenReturn(WATERMARK); @@ -243,13 +297,23 @@ public void testQueryChangeStreamWithChildPartitionsRecord() { assertEquals(ProcessContinuation.stop(), result); verify(childPartitionsRecordAction) - .run(partition, record1, restrictionTracker, watermarkEstimator); + .run( + eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator)); verify(childPartitionsRecordAction) - .run(partition, record2, restrictionTracker, watermarkEstimator); + .run( + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator)); verify(partitionMetadataDao).updateWatermark(PARTITION_TOKEN, WATERMARK_TIMESTAMP); - verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any()); - verify(heartbeatRecordAction, never()).run(any(), any(), any(), any()); + verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any(), any()); + verify(heartbeatRecordAction, never()).run(any(), any(), any(), any(), any()); verify(restrictionTracker, never()).tryClaim(any()); } @@ -279,7 +343,11 @@ public void testQueryChangeStreamWithRestrictionFromAfterPartitionStart() { when(changeStreamRecordMapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)) .thenReturn(Arrays.asList(record1, record2)); when(childPartitionsRecordAction.run( - partition, record2, restrictionTracker, watermarkEstimator)) + eq(partition), + eq(record2), + 
eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator))) .thenReturn(Optional.of(ProcessContinuation.stop())); when(watermarkEstimator.currentWatermark()).thenReturn(WATERMARK); @@ -289,13 +357,23 @@ public void testQueryChangeStreamWithRestrictionFromAfterPartitionStart() { assertEquals(ProcessContinuation.stop(), result); verify(childPartitionsRecordAction) - .run(partition, record1, restrictionTracker, watermarkEstimator); + .run( + eq(partition), + eq(record1), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator)); verify(childPartitionsRecordAction) - .run(partition, record2, restrictionTracker, watermarkEstimator); + .run( + eq(partition), + eq(record2), + eq(restrictionTracker), + any(RestrictionInterrupter.class), + eq(watermarkEstimator)); verify(partitionMetadataDao).updateWatermark(PARTITION_TOKEN, WATERMARK_TIMESTAMP); - verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any()); - verify(heartbeatRecordAction, never()).run(any(), any(), any(), any()); + verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any(), any()); + verify(heartbeatRecordAction, never()).run(any(), any(), any(), any(), any()); verify(restrictionTracker, never()).tryClaim(any()); } @@ -320,9 +398,9 @@ public void testQueryChangeStreamWithStreamFinished() { verify(partitionMetadataDao).updateWatermark(PARTITION_TOKEN, WATERMARK_TIMESTAMP); verify(partitionMetadataDao).updateToFinished(PARTITION_TOKEN); - verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any()); - verify(heartbeatRecordAction, never()).run(any(), any(), any(), any()); - verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any()); + verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any(), any()); + verify(heartbeatRecordAction, never()).run(any(), any(), any(), any(), any()); + verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any(), 
any()); } private static class BundleFinalizerStub implements BundleFinalizer { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java index 538bdf768664..87588eb8d0a9 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java @@ -149,9 +149,9 @@ public void testQueryChangeStreamMode() { verify(queryChangeStreamAction) .run(partition, tracker, receiver, watermarkEstimator, bundleFinalizer); - verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any()); - verify(heartbeatRecordAction, never()).run(any(), any(), any(), any()); - verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any()); + verify(dataChangeRecordAction, never()).run(any(), any(), any(), any(), any(), any()); + verify(heartbeatRecordAction, never()).run(any(), any(), any(), any(), any()); + verify(childPartitionsRecordAction, never()).run(any(), any(), any(), any(), any()); verify(tracker, never()).tryClaim(any()); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/restriction/RestrictionInterrupterTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/restriction/RestrictionInterrupterTest.java new file mode 100644 index 000000000000..6d376ec528ba --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/restriction/RestrictionInterrupterTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.restriction; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; + +public class RestrictionInterrupterTest { + + @Test + public void testTryInterrupt() { + RestrictionInterrupter interrupter = + new RestrictionInterrupter( + () -> Instant.ofEpochSecond(0), Duration.standardSeconds(30)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(10)); + assertFalse(interrupter.tryInterrupt(1)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(15)); + assertFalse(interrupter.tryInterrupt(2)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(30)); + assertFalse(interrupter.tryInterrupt(3)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(40)); + // Though the deadline has passed same position as previously accepted is not interrupted. 
+ assertFalse(interrupter.tryInterrupt(3)); + assertTrue(interrupter.tryInterrupt(4)); + assertTrue(interrupter.tryInterrupt(5)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(50)); + assertTrue(interrupter.tryInterrupt(5)); + // Even with non-monotonic clock the interrupter will now always interrupt. + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(40)); + assertTrue(interrupter.tryInterrupt(5)); + } + + @Test + public void testTryInterruptNoPreviousPosition() { + RestrictionInterrupter interrupter = + new RestrictionInterrupter( + () -> Instant.ofEpochSecond(0), Duration.standardSeconds(30)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(40)); + assertFalse(interrupter.tryInterrupt(1)); + // Though the deadline has passed same position as previously accepted is not interrupted. + assertFalse(interrupter.tryInterrupt(1)); + assertTrue(interrupter.tryInterrupt(2)); + interrupter.setTimeSupplier(() -> Instant.ofEpochSecond(50)); + assertTrue(interrupter.tryInterrupt(3)); + } +} diff --git a/sdks/java/io/hadoop-common/build.gradle b/sdks/java/io/hadoop-common/build.gradle index b0303d29ff98..4375001ffa81 100644 --- a/sdks/java/io/hadoop-common/build.gradle +++ b/sdks/java/io/hadoop-common/build.gradle @@ -28,7 +28,7 @@ def hadoopVersions = [ "2102": "2.10.2", "324": "3.2.4", "336": "3.3.6", - "341": "3.4.1", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-file-system/build.gradle b/sdks/java/io/hadoop-file-system/build.gradle index fafa8b5c7e34..b4ebbfa08c5e 100644 --- a/sdks/java/io/hadoop-file-system/build.gradle +++ b/sdks/java/io/hadoop-file-system/build.gradle @@ -29,7 +29,7 @@ def hadoopVersions = [ "2102": "2.10.2", "324": "3.2.4", "336": "3.3.6", - "341": "3.4.1", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> 
configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle index 4664005a1fc8..73fc44a0f311 100644 --- a/sdks/java/io/hadoop-format/build.gradle +++ b/sdks/java/io/hadoop-format/build.gradle @@ -33,7 +33,7 @@ def hadoopVersions = [ "2102": "2.10.2", "324": "3.2.4", "336": "3.3.6", - "341": "3.4.1", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hbase/build.gradle b/sdks/java/io/hbase/build.gradle index d85c0fc610bb..07014f2d5e3b 100644 --- a/sdks/java/io/hbase/build.gradle +++ b/sdks/java/io/hbase/build.gradle @@ -34,7 +34,7 @@ test { jvmArgs "-Dtest.build.data.basedirectory=build/test-data" } -def hbase_version = "2.5.5" +def hbase_version = "2.6.1-hadoop3" dependencies { implementation library.java.vendored_guava_32_1_2_jre @@ -46,12 +46,7 @@ dependencies { testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation library.java.junit testImplementation library.java.hamcrest - testImplementation library.java.hadoop_minicluster - testImplementation library.java.hadoop_hdfs - testImplementation library.java.hadoop_common + // shaded-testing-utils has shaded all Hadoop/HBase dependencies testImplementation("org.apache.hbase:hbase-shaded-testing-util:$hbase_version") - testImplementation "org.apache.hbase:hbase-hadoop-compat:$hbase_version:tests" - testImplementation "org.apache.hbase:hbase-hadoop2-compat:$hbase_version:tests" testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") } - diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle index 364c10fa738b..d07904f3465e 100644 --- a/sdks/java/io/hcatalog/build.gradle +++ b/sdks/java/io/hcatalog/build.gradle @@ -33,7 +33,7 @@ def hadoopVersions = [ "2102": "2.10.2", "324": "3.2.4", "336": "3.3.6", - "341": "3.4.1", + // 
"341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} @@ -71,13 +71,21 @@ dependencies { testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") hadoopVersions.each {kv -> "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-common:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-hdfs:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-hdfs-client:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" } } hadoopVersions.each {kv -> configurations."hadoopVersion$kv.key" { resolutionStrategy { + force "org.apache.hadoop:hadoop-client:$kv.value" force "org.apache.hadoop:hadoop-common:$kv.value" + force "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" + force "org.apache.hadoop:hadoop-minicluster:$kv.value" + force "org.apache.hadoop:hadoop-hdfs:$kv.value" + force "org.apache.hadoop:hadoop-hdfs-client:$kv.value" } } } diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle index a2d192b67208..f7a9e5c8533d 100644 --- a/sdks/java/io/iceberg/build.gradle +++ b/sdks/java/io/iceberg/build.gradle @@ -37,14 +37,16 @@ def hadoopVersions = [ hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} -def iceberg_version = "1.4.2" +def iceberg_version = "1.6.1" def parquet_version = "1.12.0" def orc_version = "1.9.2" +def hive_version = "3.1.3" dependencies { implementation library.java.vendored_guava_32_1_2_jre implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(path: ":model:pipeline", configuration: "shadow") + implementation library.java.avro implementation library.java.slf4j_api implementation library.java.joda_time implementation "org.apache.parquet:parquet-column:$parquet_version" @@ -55,22 +57,38 @@ dependencies { implementation "org.apache.iceberg:iceberg-orc:$iceberg_version" runtimeOnly 
"org.apache.iceberg:iceberg-gcp:$iceberg_version" implementation library.java.hadoop_common + implementation library.java.jackson_core + implementation library.java.jackson_databind testImplementation project(":sdks:java:managed") testImplementation library.java.hadoop_client testImplementation library.java.bigdataoss_gcsio testImplementation library.java.bigdataoss_gcs_connector testImplementation library.java.bigdataoss_util_hadoop - testImplementation "org.apache.iceberg:iceberg-gcp:$iceberg_version" testImplementation "org.apache.iceberg:iceberg-data:$iceberg_version" testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation project(":sdks:java:extensions:google-cloud-platform-core") testImplementation library.java.junit + + // Hive catalog test dependencies + testImplementation project(path: ":sdks:java:io:iceberg:hive") + testImplementation "org.apache.iceberg:iceberg-common:$iceberg_version" + testImplementation ("org.apache.iceberg:iceberg-hive-metastore:$iceberg_version") + testImplementation ("org.apache.hive:hive-metastore:$hive_version") + testImplementation "org.assertj:assertj-core:3.11.1" + testRuntimeOnly ("org.apache.hive.hcatalog:hive-hcatalog-core:$hive_version") { + exclude group: "org.apache.hive", module: "hive-exec" + exclude group: "org.apache.parquet", module: "parquet-hadoop-bundle" + } + testRuntimeOnly library.java.slf4j_jdk14 testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") testRuntimeOnly project(path: ":runners:google-cloud-dataflow-java") hadoopVersions.each {kv -> "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-client:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-minicluster:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-hdfs-client:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" } } @@ -78,6 +96,11 @@ hadoopVersions.each {kv -> configurations."hadoopVersion$kv.key" { 
resolutionStrategy { force "org.apache.hadoop:hadoop-client:$kv.value" + force "org.apache.hadoop:hadoop-common:$kv.value" + force "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" + force "org.apache.hadoop:hadoop-minicluster:$kv.value" + force "org.apache.hadoop:hadoop-hdfs:$kv.value" + force "org.apache.hadoop:hadoop-hdfs-client:$kv.value" } } } @@ -101,7 +124,7 @@ hadoopVersions.each { kv -> task integrationTest(type: Test) { group = "Verification" def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' - def gcpTempLocation = project.findProperty('gcpTempLocation') ?: 'gs://temp-storage-for-end-to-end-tests' + def gcpTempLocation = project.findProperty('gcpTempLocation') ?: 'gs://managed-iceberg-integration-tests' systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--project=${gcpProject}", "--tempLocation=${gcpTempLocation}", @@ -117,11 +140,6 @@ task integrationTest(type: Test) { testClassesDirs = sourceSets.test.output.classesDirs } -tasks.register('catalogTests') { - dependsOn integrationTest - dependsOn ":sdks:java:io:iceberg:hive:integrationTest" -} - task loadTest(type: Test) { def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' def gcpTempLocation = project.findProperty('gcpTempLocation') ?: 'gs://temp-storage-for-end-to-end-tests/temp-lt' diff --git a/sdks/java/io/iceberg/hive/build.gradle b/sdks/java/io/iceberg/hive/build.gradle index bfa6c75251c4..7d93a4026775 100644 --- a/sdks/java/io/iceberg/hive/build.gradle +++ b/sdks/java/io/iceberg/hive/build.gradle @@ -21,60 +21,39 @@ plugins { id 'org.apache.beam.module' } applyJavaNature( automaticModuleName: 'org.apache.beam.sdk.io.iceberg.hive', exportJavadoc: false, - shadowClosure: {}, + publish: false, // it's an intermediate jar for io-expansion-service ) description = "Apache Beam :: SDKs :: Java :: IO :: Iceberg :: Hive" ext.summary = "Runtime dependencies needed for Hive catalog integration." 
def hive_version = "3.1.3" -def iceberg_version = "1.4.2" +def hbase_version = "2.6.1-hadoop3" +def hadoop_version = "3.4.1" +def iceberg_version = "1.6.1" +def avatica_version = "1.25.0" dependencies { // dependencies needed to run with iceberg's hive catalog + // these dependencies are going to be included in io-expansion-service runtimeOnly ("org.apache.iceberg:iceberg-hive-metastore:$iceberg_version") - runtimeOnly project(path: ":sdks:java:io:iceberg:hive:exec", configuration: "shadow") - - // ----- below dependencies are for testing and will not appear in the shaded jar ----- - // Beam IcebergIO dependencies - testImplementation project(path: ":sdks:java:core", configuration: "shadow") - testImplementation project(":sdks:java:managed") - testImplementation project(":sdks:java:io:iceberg") - testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") - testRuntimeOnly library.java.snake_yaml - testRuntimeOnly library.java.bigdataoss_gcs_connector - testRuntimeOnly library.java.hadoop_client - - // needed to set up the test environment - testImplementation "org.apache.iceberg:iceberg-common:$iceberg_version" - testImplementation "org.apache.iceberg:iceberg-core:$iceberg_version" - testImplementation "org.assertj:assertj-core:3.11.1" - testImplementation library.java.junit - - // needed to set up test Hive Metastore and run tests - testImplementation ("org.apache.iceberg:iceberg-hive-metastore:$iceberg_version") - testImplementation project(path: ":sdks:java:io:iceberg:hive:exec", configuration: "shadow") - testRuntimeOnly ("org.apache.hive.hcatalog:hive-hcatalog-core:$hive_version") { - exclude group: "org.apache.hive", module: "hive-exec" - exclude group: "org.apache.parquet", module: "parquet-hadoop-bundle" + // analyzeClassesDependencies fails with "Cannot accept visitor on URL", likely the plugin does not recognize "core" classifier + // use "core" classifier to depend on un-shaded jar + runtimeOnly 
("org.apache.hive:hive-exec:$hive_version:core") { + // old hadoop-yarn-server-resourcemanager contains critical log4j vulneribility + exclude group: "org.apache.hadoop", module: "hadoop-yarn-server-resourcemanager" + // old hadoop-yarn-server-resourcemanager contains critical log4j and hadoop vulneribility + exclude group: "org.apache.hbase", module: "hbase-client" + // old calcite leaks old protobuf-java + exclude group: "org.apache.calcite.avatica", module: "avatica" } - testImplementation "org.apache.iceberg:iceberg-parquet:$iceberg_version" - testImplementation "org.apache.parquet:parquet-column:1.12.0" + runtimeOnly ("org.apache.hadoop:hadoop-yarn-server-resourcemanager:$hadoop_version") + runtimeOnly ("org.apache.hbase:hbase-client:$hbase_version") + runtimeOnly ("org.apache.calcite.avatica:avatica-core:$avatica_version") + runtimeOnly ("org.apache.hive:hive-metastore:$hive_version") } -task integrationTest(type: Test) { - group = "Verification" - def gcpTempLocation = project.findProperty('gcpTempLocation') ?: 'gs://temp-storage-for-end-to-end-tests/iceberg-hive-it' - systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ - "--tempLocation=${gcpTempLocation}", - ]) - - // Disable Gradle cache: these ITs interact with live service that should always be considered "out of date" - outputs.upToDateWhen { false } - - include '**/*IT.class' - - maxParallelForks 4 - classpath = sourceSets.test.runtimeClasspath - testClassesDirs = sourceSets.test.output.classesDirs -} \ No newline at end of file +configurations.all { + // the fatjar "parquet-hadoop-bundle" conflicts with "parquet-hadoop" used by org.apache.iceberg:iceberg-parquet + exclude group: "org.apache.parquet", module: "parquet-hadoop-bundle" +} diff --git a/sdks/java/io/iceberg/hive/exec/build.gradle b/sdks/java/io/iceberg/hive/exec/build.gradle deleted file mode 100644 index f266ab2ef4db..000000000000 --- a/sdks/java/io/iceberg/hive/exec/build.gradle +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed 
to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -plugins { - id 'org.apache.beam.module' - id 'java' - id 'com.github.johnrengelman.shadow' -} - -dependencies { - implementation("org.apache.hive:hive-exec:3.1.3") - permitUnusedDeclared("org.apache.hive:hive-exec:3.1.3") -} - -configurations { - shadow -} - -artifacts { - shadow(archives(shadowJar) { - builtBy shadowJar - }) -} - -shadowJar { - zip64 true - - def problematicPackages = [ - 'com.google.protobuf', - 'com.google.common', - 'shaded.parquet', - 'org.apache.parquet', - 'org.joda' - ] - - problematicPackages.forEach { - relocate it, getJavaRelocatedPath("iceberg.hive.${it}") - } - - version "3.1.3" - mergeServiceFiles() - - exclude 'LICENSE' - exclude( - 'org/xml/**', - 'javax/**', - 'com/sun/**' - ) -} -description = "Apache Beam :: SDKs :: Java :: IO :: Iceberg :: Hive :: Exec" -ext.summary = "A copy of the hive-exec dependency with some popular libraries relocated." 
diff --git a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java b/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java deleted file mode 100644 index ca4d862c2c72..000000000000 --- a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.iceberg.hive; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsInAnyOrder; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.LongStream; -import org.apache.beam.sdk.io.iceberg.IcebergUtils; -import org.apache.beam.sdk.io.iceberg.hive.testutils.HiveMetastoreExtension; -import org.apache.beam.sdk.managed.Managed; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.metastore.api.Database; -import org.apache.iceberg.AppendFiles; -import org.apache.iceberg.CatalogProperties; -import org.apache.iceberg.CatalogUtil; -import org.apache.iceberg.CombinedScanTask; -import org.apache.iceberg.DataFile; -import org.apache.iceberg.FileFormat; -import org.apache.iceberg.FileScanTask; -import org.apache.iceberg.ManifestFiles; -import org.apache.iceberg.ManifestWriter; -import org.apache.iceberg.Table; -import org.apache.iceberg.TableScan; -import org.apache.iceberg.catalog.TableIdentifier; -import org.apache.iceberg.data.Record; -import org.apache.iceberg.data.parquet.GenericParquetReaders; -import org.apache.iceberg.data.parquet.GenericParquetWriter; -import org.apache.iceberg.encryption.InputFilesDecryptor; -import 
org.apache.iceberg.hive.HiveCatalog; -import org.apache.iceberg.io.CloseableIterable; -import org.apache.iceberg.io.DataWriter; -import org.apache.iceberg.io.InputFile; -import org.apache.iceberg.io.OutputFile; -import org.apache.iceberg.parquet.Parquet; -import org.apache.iceberg.util.DateTimeUtil; -import org.apache.thrift.TException; -import org.joda.time.DateTime; -import org.joda.time.DateTimeZone; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; - -/** - * Read and write test for {@link Managed} {@link org.apache.beam.sdk.io.iceberg.IcebergIO} using - * {@link HiveCatalog}. - * - *

Spins up a local Hive metastore to manage the Iceberg table. Warehouse path is set to a GCS - * bucket. - */ -public class IcebergHiveCatalogIT { - private static final Schema DOUBLY_NESTED_ROW_SCHEMA = - Schema.builder() - .addStringField("doubly_nested_str") - .addInt64Field("doubly_nested_float") - .build(); - - private static final Schema NESTED_ROW_SCHEMA = - Schema.builder() - .addStringField("nested_str") - .addInt32Field("nested_int") - .addFloatField("nested_float") - .addRowField("nested_row", DOUBLY_NESTED_ROW_SCHEMA) - .build(); - private static final Schema BEAM_SCHEMA = - Schema.builder() - .addStringField("str") - .addBooleanField("bool") - .addNullableInt32Field("nullable_int") - .addNullableInt64Field("nullable_long") - .addArrayField("arr_long", Schema.FieldType.INT64) - .addRowField("row", NESTED_ROW_SCHEMA) - .addNullableRowField("nullable_row", NESTED_ROW_SCHEMA) - .addDateTimeField("datetime_tz") - .addLogicalTypeField("datetime", SqlTypes.DATETIME) - .addLogicalTypeField("date", SqlTypes.DATE) - .addLogicalTypeField("time", SqlTypes.TIME) - .build(); - - private static final SimpleFunction ROW_FUNC = - new SimpleFunction() { - @Override - public Row apply(Long num) { - String strNum = Long.toString(num); - Row nestedRow = - Row.withSchema(NESTED_ROW_SCHEMA) - .addValue("nested_str_value_" + strNum) - .addValue(Integer.valueOf(strNum)) - .addValue(Float.valueOf(strNum + "." + strNum)) - .addValue( - Row.withSchema(DOUBLY_NESTED_ROW_SCHEMA) - .addValue("doubly_nested_str_value_" + strNum) - .addValue(num) - .build()) - .build(); - - return Row.withSchema(BEAM_SCHEMA) - .addValue("str_value_" + strNum) - .addValue(num % 2 == 0) - .addValue(Integer.valueOf(strNum)) - .addValue(num) - .addValue(LongStream.range(1, num % 10).boxed().collect(Collectors.toList())) - .addValue(nestedRow) - .addValue(num % 2 == 0 ? 
null : nestedRow) - .addValue(new DateTime(num).withZone(DateTimeZone.forOffsetHoursMinutes(3, 25))) - .addValue(DateTimeUtil.timestampFromMicros(num)) - .addValue(DateTimeUtil.dateFromDays(Integer.parseInt(strNum))) - .addValue(DateTimeUtil.timeFromMicros(num)) - .build(); - } - }; - - private static final org.apache.iceberg.Schema ICEBERG_SCHEMA = - IcebergUtils.beamSchemaToIcebergSchema(BEAM_SCHEMA); - private static final SimpleFunction RECORD_FUNC = - new SimpleFunction() { - @Override - public Record apply(Row input) { - return IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, input); - } - }; - - private static HiveMetastoreExtension hiveMetastoreExtension; - - @Rule public TestPipeline writePipeline = TestPipeline.create(); - - @Rule public TestPipeline readPipeline = TestPipeline.create(); - - private static final String TEST_CATALOG = "test_catalog"; - private static final String TEST_TABLE = "test_table"; - private static HiveCatalog catalog; - private static final String TEST_DB = "test_db_" + System.nanoTime(); - - @BeforeClass - public static void setUp() throws TException { - String warehousePath = TestPipeline.testingPipelineOptions().getTempLocation(); - hiveMetastoreExtension = new HiveMetastoreExtension(warehousePath); - catalog = - (HiveCatalog) - CatalogUtil.loadCatalog( - HiveCatalog.class.getName(), - TEST_CATALOG, - ImmutableMap.of( - CatalogProperties.CLIENT_POOL_CACHE_EVICTION_INTERVAL_MS, - String.valueOf(TimeUnit.SECONDS.toMillis(10))), - hiveMetastoreExtension.hiveConf()); - - String dbPath = hiveMetastoreExtension.metastore().getDatabasePath(TEST_DB); - Database db = new Database(TEST_DB, "description", dbPath, Maps.newHashMap()); - hiveMetastoreExtension.metastoreClient().createDatabase(db); - } - - @AfterClass - public static void cleanup() throws Exception { - hiveMetastoreExtension.cleanup(); - } - - private Map getManagedIcebergConfig(TableIdentifier table) { - String metastoreUri = 
hiveMetastoreExtension.hiveConf().getVar(HiveConf.ConfVars.METASTOREURIS); - - Map confProperties = - ImmutableMap.builder() - .put(HiveConf.ConfVars.METASTOREURIS.varname, metastoreUri) - .build(); - - return ImmutableMap.builder() - .put("table", table.toString()) - .put("config_properties", confProperties) - .build(); - } - - @Test - public void testReadWithHiveCatalog() throws IOException { - TableIdentifier tableIdentifier = - TableIdentifier.parse(String.format("%s.%s", TEST_DB, TEST_TABLE + "_read_test")); - Table table = catalog.createTable(tableIdentifier, ICEBERG_SCHEMA); - - List expectedRows = - LongStream.range(1, 1000).boxed().map(ROW_FUNC::apply).collect(Collectors.toList()); - List records = - expectedRows.stream().map(RECORD_FUNC::apply).collect(Collectors.toList()); - - // write iceberg records with hive catalog - String filepath = table.location() + "/" + UUID.randomUUID(); - DataWriter writer = - Parquet.writeData(table.io().newOutputFile(filepath)) - .schema(ICEBERG_SCHEMA) - .createWriterFunc(GenericParquetWriter::buildWriter) - .overwrite() - .withSpec(table.spec()) - .build(); - for (Record rec : records) { - writer.write(rec); - } - writer.close(); - AppendFiles appendFiles = table.newAppend(); - String manifestFilename = FileFormat.AVRO.addExtension(filepath + ".manifest"); - OutputFile outputFile = table.io().newOutputFile(manifestFilename); - ManifestWriter manifestWriter; - try (ManifestWriter openWriter = ManifestFiles.write(table.spec(), outputFile)) { - openWriter.add(writer.toDataFile()); - manifestWriter = openWriter; - } - appendFiles.appendManifest(manifestWriter.toManifestFile()); - appendFiles.commit(); - - // Run Managed Iceberg read - PCollection outputRows = - readPipeline - .apply( - Managed.read(Managed.ICEBERG).withConfig(getManagedIcebergConfig(tableIdentifier))) - .getSinglePCollection(); - PAssert.that(outputRows).containsInAnyOrder(expectedRows); - readPipeline.run().waitUntilFinish(); - } - - @Test - public void 
testWriteWithHiveCatalog() { - TableIdentifier tableIdentifier = - TableIdentifier.parse(String.format("%s.%s", TEST_DB, TEST_TABLE + "_write_test")); - catalog.createTable(tableIdentifier, IcebergUtils.beamSchemaToIcebergSchema(BEAM_SCHEMA)); - - List inputRows = - LongStream.range(1, 1000).mapToObj(ROW_FUNC::apply).collect(Collectors.toList()); - List expectedRecords = - inputRows.stream().map(RECORD_FUNC::apply).collect(Collectors.toList()); - - // Run Managed Iceberg write - writePipeline - .apply(Create.of(inputRows)) - .setRowSchema(BEAM_SCHEMA) - .apply(Managed.write(Managed.ICEBERG).withConfig(getManagedIcebergConfig(tableIdentifier))); - writePipeline.run().waitUntilFinish(); - - // read back the records and check everything's there - Table table = catalog.loadTable(tableIdentifier); - TableScan tableScan = table.newScan().project(ICEBERG_SCHEMA); - List writtenRecords = new ArrayList<>(); - for (CombinedScanTask task : tableScan.planTasks()) { - InputFilesDecryptor decryptor = new InputFilesDecryptor(task, table.io(), table.encryption()); - for (FileScanTask fileTask : task.files()) { - InputFile inputFile = decryptor.getInputFile(fileTask); - CloseableIterable iterable = - Parquet.read(inputFile) - .split(fileTask.start(), fileTask.length()) - .project(ICEBERG_SCHEMA) - .createReaderFunc( - fileSchema -> GenericParquetReaders.buildReader(ICEBERG_SCHEMA, fileSchema)) - .filter(fileTask.residual()) - .build(); - - for (Record rec : iterable) { - writtenRecords.add(rec); - } - } - } - assertThat(expectedRecords, containsInAnyOrder(writtenRecords.toArray())); - } -} diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java index d9768114e7c6..deec779c6cc9 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java +++ 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java @@ -47,7 +47,6 @@ import org.apache.iceberg.Snapshot; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; -import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.io.FileIO; import org.apache.iceberg.io.OutputFile; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; @@ -134,7 +133,7 @@ public void processElement( return; } - Table table = getCatalog().loadTable(TableIdentifier.parse(element.getKey())); + Table table = getCatalog().loadTable(IcebergUtils.parseTableIdentifier(element.getKey())); // vast majority of the time, we will simply append data files. // in the rare case we get a batch that contains multiple partition specs, we will group diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java index bf00bf8519fc..d58ac8696d37 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java @@ -25,6 +25,7 @@ import org.apache.iceberg.DataFile; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.catalog.TableIdentifierParser; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; @AutoValue @@ -41,7 +42,7 @@ abstract class FileWriteResult { @SchemaIgnore public TableIdentifier getTableIdentifier() { if (cachedTableIdentifier == null) { - cachedTableIdentifier = TableIdentifier.parse(getTableIdentifierString()); + cachedTableIdentifier = IcebergUtils.parseTableIdentifier(getTableIdentifierString()); } return cachedTableIdentifier; } @@ -67,7 +68,7 @@ abstract static class Builder { @SchemaIgnore public Builder setTableIdentifier(TableIdentifier tableId) { - return 
setTableIdentifierString(tableId.toString()); + return setTableIdentifierString(TableIdentifierParser.toJson(tableId)); } public abstract FileWriteResult build(); diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java index d44149fda08e..951442e2c95f 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java @@ -31,7 +31,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; -import org.apache.iceberg.catalog.TableIdentifier; /** * SchemaTransform implementation for {@link IcebergIO#readRows}. Reads records from Iceberg and @@ -86,7 +85,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { .getPipeline() .apply( IcebergIO.readRows(configuration.getIcebergCatalog()) - .from(TableIdentifier.parse(configuration.getTable()))); + .from(IcebergUtils.parseTableIdentifier(configuration.getTable()))); return PCollectionRowTuple.of(OUTPUT_TAG, output); } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java index 60372b172af7..640283d83c2e 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java @@ -23,6 +23,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.catalog.TableIdentifierParser; import 
org.apache.iceberg.expressions.Expression; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.Nullable; @@ -51,7 +52,9 @@ public enum ScanType { public Table getTable() { if (cachedTable == null) { cachedTable = - getCatalogConfig().catalog().loadTable(TableIdentifier.parse(getTableIdentifier())); + getCatalogConfig() + .catalog() + .loadTable(IcebergUtils.parseTableIdentifier(getTableIdentifier())); } return cachedTable; } @@ -126,7 +129,7 @@ public abstract static class Builder { public abstract Builder setTableIdentifier(String tableIdentifier); public Builder setTableIdentifier(TableIdentifier tableIdentifier) { - return this.setTableIdentifier(tableIdentifier.toString()); + return this.setTableIdentifier(TableIdentifierParser.toJson(tableIdentifier)); } public Builder setTableIdentifier(String... names) { diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java index ef19a5881366..bd2f743172dc 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java @@ -19,6 +19,9 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.ByteBuffer; import java.time.LocalDate; import java.time.LocalDateTime; @@ -36,6 +39,8 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.catalog.TableIdentifierParser; import 
org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Type; @@ -47,6 +52,9 @@ /** Utilities for converting between Beam and Iceberg types, made public for user's convenience. */ public class IcebergUtils { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private IcebergUtils() {} private static final Map BEAM_TYPES_TO_ICEBERG_TYPES = @@ -506,4 +514,13 @@ private static Object getLogicalTypeValue(Object icebergValue, Schema.FieldType // LocalDateTime, LocalDate, LocalTime return icebergValue; } + + public static TableIdentifier parseTableIdentifier(String table) { + try { + JsonNode jsonNode = OBJECT_MAPPER.readTree(table); + return TableIdentifierParser.fromJson(jsonNode); + } catch (JsonProcessingException e) { + return TableIdentifier.parse(table); + } + } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/OneTableDynamicDestinations.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/OneTableDynamicDestinations.java index 861a8ad198a8..be810aa20a13 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/OneTableDynamicDestinations.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/OneTableDynamicDestinations.java @@ -41,7 +41,7 @@ class OneTableDynamicDestinations implements DynamicDestinations, Externalizable @VisibleForTesting TableIdentifier getTableIdentifier() { if (tableId == null) { - tableId = TableIdentifier.parse(checkStateNotNull(tableIdString)); + tableId = IcebergUtils.parseTableIdentifier(checkStateNotNull(tableIdString)); } return tableId; } @@ -86,6 +86,6 @@ public void writeExternal(ObjectOutput out) throws IOException { @Override public void readExternal(ObjectInput in) throws IOException { tableIdString = in.readUTF(); - tableId = TableIdentifier.parse(tableIdString); + tableId = IcebergUtils.parseTableIdentifier(tableIdString); } } diff --git 
a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/PortableIcebergDestinations.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/PortableIcebergDestinations.java index 47f661bba3f8..58f70463bc76 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/PortableIcebergDestinations.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/PortableIcebergDestinations.java @@ -24,7 +24,6 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.ValueInSingleWindow; import org.apache.iceberg.FileFormat; -import org.apache.iceberg.catalog.TableIdentifier; import org.checkerframework.checker.nullness.qual.Nullable; class PortableIcebergDestinations implements DynamicDestinations { @@ -73,7 +72,7 @@ public String getTableStringIdentifier(ValueInSingleWindow element) { @Override public IcebergDestination instantiateDestination(String dest) { return IcebergDestination.builder() - .setTableIdentifier(TableIdentifier.parse(dest)) + .setTableIdentifier(IcebergUtils.parseTableIdentifier(dest)) .setTableCreateConfig(null) .setFileFormat(FileFormat.fromString(fileFormat)) .build(); diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java index 255fce9ece4e..4c21a0175ab0 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java @@ -21,6 +21,11 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.IOException; +import java.time.LocalDateTime; +import java.time.YearMonth; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.List; import 
java.util.Map; @@ -31,6 +36,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; @@ -38,14 +44,20 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.iceberg.DataFile; import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.PartitionField; import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.exceptions.AlreadyExistsException; import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,6 +102,7 @@ class DestinationState { final Cache writers; private final List dataFiles = Lists.newArrayList(); @VisibleForTesting final Map writerCounts = Maps.newHashMap(); + private final Map partitionFieldMap = Maps.newHashMap(); private final List exceptions = Lists.newArrayList(); DestinationState(IcebergDestination icebergDestination, Table table) { @@ -98,6 +111,9 @@ class DestinationState { this.spec = table.spec(); this.partitionKey = new PartitionKey(spec, schema); this.table = table; 
+ for (PartitionField partitionField : spec.fields()) { + partitionFieldMap.put(partitionField.name(), partitionField); + } // build a cache of RecordWriters. // writers will expire after 1 min of idle time. @@ -123,7 +139,9 @@ class DestinationState { throw rethrow; } openWriters--; - dataFiles.add(SerializableDataFile.from(recordWriter.getDataFile(), pk)); + String partitionPath = getPartitionDataPath(pk.toPath(), partitionFieldMap); + dataFiles.add( + SerializableDataFile.from(recordWriter.getDataFile(), partitionPath)); }) .build(); } @@ -136,7 +154,7 @@ class DestinationState { * can't create a new writer, the {@link Record} is rejected and {@code false} is returned. */ boolean write(Record record) { - partitionKey.partition(record); + partitionKey.partition(getPartitionableRecord(record)); if (!writers.asMap().containsKey(partitionKey) && openWriters >= maxNumWriters) { return false; @@ -185,8 +203,65 @@ private RecordWriter createWriter(PartitionKey partitionKey) { e); } } + + /** + * Resolves an input {@link Record}'s partition values and returns another {@link Record} that + * can be applied to the destination's {@link PartitionSpec}. + */ + private Record getPartitionableRecord(Record record) { + if (spec.isUnpartitioned()) { + return record; + } + Record output = GenericRecord.create(schema); + for (PartitionField partitionField : spec.fields()) { + Transform transform = partitionField.transform(); + Types.NestedField field = schema.findField(partitionField.sourceId()); + String name = field.name(); + Object value = record.getField(name); + @Nullable Literal literal = Literal.of(value.toString()).to(field.type()); + if (literal == null || transform.isVoid() || transform.isIdentity()) { + output.setField(name, value); + } else { + output.setField(name, literal.value()); + } + } + return output; + } } + /** + * Returns an equivalent partition path that is made up of partition data. Needed to reconstruct a + * {@link DataFile}. 
+ */ + @VisibleForTesting + static String getPartitionDataPath( + String partitionPath, Map partitionFieldMap) { + if (partitionPath.isEmpty() || partitionFieldMap.isEmpty()) { + return partitionPath; + } + List resolved = new ArrayList<>(); + for (String partition : Splitter.on('/').splitToList(partitionPath)) { + List nameAndValue = Splitter.on('=').splitToList(partition); + String name = nameAndValue.get(0); + String value = nameAndValue.get(1); + String transformName = + Preconditions.checkArgumentNotNull(partitionFieldMap.get(name)).transform().toString(); + if (Transforms.month().toString().equals(transformName)) { + int month = YearMonth.parse(value).getMonthValue(); + value = String.valueOf(month); + } else if (Transforms.hour().toString().equals(transformName)) { + long hour = ChronoUnit.HOURS.between(EPOCH, LocalDateTime.parse(value, HOUR_FORMATTER)); + value = String.valueOf(hour); + } + resolved.add(name + "=" + value); + } + return String.join("/", resolved); + } + + private static final DateTimeFormatter HOUR_FORMATTER = + DateTimeFormatter.ofPattern("yyyy-MM-dd-HH"); + private static final LocalDateTime EPOCH = LocalDateTime.ofEpochSecond(0, 0, ZoneOffset.UTC); + private final Catalog catalog; private final String filePrefix; private final long maxFileSize; diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java index b7cb42b2eacb..5784dfd79744 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java @@ -21,15 +21,22 @@ import java.io.IOException; import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Map; import java.util.NoSuchElementException; import java.util.Queue; +import java.util.Set; +import java.util.function.BiFunction; import javax.annotation.Nullable; 
import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.values.Row; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.data.IdentityPartitionConverters; import org.apache.iceberg.data.Record; import org.apache.iceberg.data.avro.DataReader; import org.apache.iceberg.data.orc.GenericOrcReader; @@ -42,6 +49,9 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.orc.ORC; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.PartitionUtil; import org.checkerframework.checker.nullness.qual.NonNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -112,6 +122,8 @@ public boolean advance() throws IOException { FileScanTask fileTask = fileScanTasks.remove(); DataFile file = fileTask.file(); InputFile input = decryptor.getInputFile(fileTask); + Map idToConstants = + constantsMap(fileTask, IdentityPartitionConverters::convertConstant, project); CloseableIterable iterable; switch (file.format()) { @@ -121,7 +133,9 @@ public boolean advance() throws IOException { ORC.read(input) .split(fileTask.start(), fileTask.length()) .project(project) - .createReaderFunc(fileSchema -> GenericOrcReader.buildReader(project, fileSchema)) + .createReaderFunc( + fileSchema -> + GenericOrcReader.buildReader(project, fileSchema, idToConstants)) .filter(fileTask.residual()) .build(); break; @@ -132,7 +146,8 @@ public boolean advance() throws IOException { .split(fileTask.start(), fileTask.length()) .project(project) .createReaderFunc( - fileSchema -> GenericParquetReaders.buildReader(project, fileSchema)) + fileSchema -> + GenericParquetReaders.buildReader(project, fileSchema, idToConstants)) .filter(fileTask.residual()) .build(); break; @@ -142,7 +157,8 @@ 
public boolean advance() throws IOException { Avro.read(input) .split(fileTask.start(), fileTask.length()) .project(project) - .createReaderFunc(DataReader::create) + .createReaderFunc( + fileSchema -> DataReader.create(project, fileSchema, idToConstants)) .build(); break; default: @@ -155,6 +171,20 @@ public boolean advance() throws IOException { return false; } + private Map constantsMap( + FileScanTask task, BiFunction converter, Schema schema) { + PartitionSpec spec = task.spec(); + Set idColumns = spec.identitySourceIds(); + Schema partitionSchema = TypeUtil.select(schema, idColumns); + boolean projectsIdentityPartitionColumns = !partitionSchema.columns().isEmpty(); + + if (projectsIdentityPartitionColumns) { + return PartitionUtil.constantsMap(task, converter); + } else { + return Collections.emptyMap(); + } + } + @Override public Row getCurrent() throws NoSuchElementException { if (current == null) { diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java index 59b456162008..eef2b154d243 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java @@ -116,13 +116,14 @@ abstract static class Builder { * Create a {@link SerializableDataFile} from a {@link DataFile} and its associated {@link * PartitionKey}. 
*/ - static SerializableDataFile from(DataFile f, PartitionKey key) { + static SerializableDataFile from(DataFile f, String partitionPath) { + return SerializableDataFile.builder() .setPath(f.path().toString()) .setFileFormat(f.format().toString()) .setRecordCount(f.recordCount()) .setFileSizeInBytes(f.fileSizeInBytes()) - .setPartitionPath(key.toPath()) + .setPartitionPath(partitionPath) .setPartitionSpecId(f.specId()) .setKeyMetadata(f.keyMetadata()) .setSplitOffsets(f.splitOffsets()) diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java index fe4a07dedfdf..6ff3bdf6a4ff 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java @@ -17,9 +17,12 @@ */ package org.apache.beam.sdk.io.iceberg; +import static org.apache.beam.sdk.io.iceberg.TestFixtures.createRecord; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.UUID; @@ -35,18 +38,21 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.types.Types; import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -@RunWith(JUnit4.class) 
+@RunWith(Parameterized.class) public class IcebergIOReadTest { private static final Logger LOG = LoggerFactory.getLogger(IcebergIOReadTest.class); @@ -57,6 +63,21 @@ public class IcebergIOReadTest { @Rule public TestPipeline testPipeline = TestPipeline.create(); + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList( + new Object[][] { + {String.format("{\"namespace\": [\"default\"], \"name\": \"%s\"}", tableId())}, + {String.format("default.%s", tableId())}, + }); + } + + public static String tableId() { + return "table" + Long.toString(UUID.randomUUID().hashCode(), 16); + } + + @Parameterized.Parameter public String tableStringIdentifier; + static class PrintRow extends DoFn { @ProcessElement @@ -68,8 +89,7 @@ public void process(@Element Row row, OutputReceiver output) throws Excepti @Test public void testSimpleScan() throws Exception { - TableIdentifier tableId = - TableIdentifier.of("default", "table" + Long.toString(UUID.randomUUID().hashCode(), 16)); + TableIdentifier tableId = IcebergUtils.parseTableIdentifier(tableStringIdentifier); Table simpleTable = warehouse.createTable(tableId, TestFixtures.SCHEMA); final Schema schema = IcebergUtils.icebergSchemaToBeamSchema(TestFixtures.SCHEMA); @@ -122,4 +142,72 @@ public void testSimpleScan() throws Exception { testPipeline.run(); } + + @Test + public void testIdentityColumnScan() throws Exception { + TableIdentifier tableId = + TableIdentifier.of("default", "table" + Long.toString(UUID.randomUUID().hashCode(), 16)); + Table simpleTable = warehouse.createTable(tableId, TestFixtures.SCHEMA); + + String identityColumnName = "identity"; + String identityColumnValue = "some-value"; + simpleTable.updateSchema().addColumn(identityColumnName, Types.StringType.get()).commit(); + simpleTable.updateSpec().addField(identityColumnName).commit(); + + PartitionSpec spec = simpleTable.spec(); + PartitionKey partitionKey = new PartitionKey(simpleTable.spec(), simpleTable.schema()); + 
partitionKey.set(0, identityColumnValue); + + simpleTable + .newFastAppend() + .appendFile( + warehouse.writeRecords( + "file1s1.parquet", + TestFixtures.SCHEMA, + spec, + partitionKey, + TestFixtures.FILE1SNAPSHOT1)) + .commit(); + + final Schema schema = IcebergUtils.icebergSchemaToBeamSchema(simpleTable.schema()); + final List expectedRows = + Stream.of(TestFixtures.FILE1SNAPSHOT1_DATA) + .flatMap(List::stream) + .map( + d -> + ImmutableMap.builder() + .putAll(d) + .put(identityColumnName, identityColumnValue) + .build()) + .map(r -> createRecord(simpleTable.schema(), r)) + .map(record -> IcebergUtils.icebergRecordToBeamRow(schema, record)) + .collect(Collectors.toList()); + + Map catalogProps = + ImmutableMap.builder() + .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) + .put("warehouse", warehouse.location) + .build(); + + IcebergCatalogConfig catalogConfig = + IcebergCatalogConfig.builder() + .setCatalogName("name") + .setCatalogProperties(catalogProps) + .build(); + + PCollection output = + testPipeline + .apply(IcebergIO.readRows(catalogConfig).from(tableId)) + .apply(ParDo.of(new PrintRow())) + .setCoder(RowCoder.of(IcebergUtils.icebergSchemaToBeamSchema(simpleTable.schema()))); + + PAssert.that(output) + .satisfies( + (Iterable rows) -> { + assertThat(rows, containsInAnyOrder(expectedRows.toArray())); + return null; + }); + + testPipeline.run(); + } } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java index 134f05c34bfb..918c6b1146ee 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java @@ -19,11 +19,13 @@ import static org.apache.beam.sdk.io.iceberg.IcebergUtils.TypeAndMaxId; import static org.apache.beam.sdk.io.iceberg.IcebergUtils.beamFieldTypeToIcebergFieldType; +import 
static org.apache.beam.sdk.io.iceberg.IcebergUtils.parseTableIdentifier; import static org.apache.iceberg.types.Types.NestedField.optional; import static org.apache.iceberg.types.Types.NestedField.required; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.math.BigDecimal; @@ -32,6 +34,7 @@ import java.time.OffsetDateTime; import java.time.ZoneOffset; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.Schema; @@ -49,6 +52,7 @@ import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; /** Test class for {@link IcebergUtils}. */ @RunWith(Enclosed.class) @@ -802,4 +806,40 @@ public void testStructIcebergSchemaToBeamSchema() { assertEquals(BEAM_SCHEMA_STRUCT, convertedBeamSchema); } } + + @RunWith(Parameterized.class) + public static class TableIdentifierParseTests { + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList( + new Object[][] { + { + "{\"namespace\": [\"dogs\", \"owners.and.handlers\"], \"name\": \"food\"}", + "dogs.owners.and.handlers.food", + true + }, + {"dogs.owners.and.handlers.food", "dogs.owners.and.handlers.food", true}, + {"{\"name\": \"food\"}", "food", true}, + {"{\"table_name\": \"food\"}", "{\"table_name\": \"food\"}", false}, + }); + } + + @Parameterized.Parameter public String input; + + @Parameterized.Parameter(1) + public String expected; + + @Parameterized.Parameter(2) + public boolean shouldSucceed; + + @Test + public void test() { + if (shouldSucceed) { + assertEquals(expected, parseTableIdentifier(input).toString()); + } else { + assertThrows(IllegalArgumentException.class, () -> parseTableIdentifier(input)); + } + } + } } diff --git 
a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java index 47dc9aa425dd..9834547c4741 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java @@ -23,14 +23,19 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertEquals; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; @@ -49,12 +54,16 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.data.IcebergGenerics; import org.apache.iceberg.data.Record; +import org.apache.iceberg.util.DateTimeUtil; import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.ClassRule; @@ -360,4 +369,93 @@ public Void apply(Iterable input) { return null; } } + 
+ @Test + public void testWritePartitionedData() { + Schema schema = + Schema.builder() + .addStringField("str") + .addInt32Field("int") + .addLogicalTypeField("y_date", SqlTypes.DATE) + .addLogicalTypeField("y_datetime", SqlTypes.DATETIME) + .addDateTimeField("y_datetime_tz") + .addLogicalTypeField("m_date", SqlTypes.DATE) + .addLogicalTypeField("m_datetime", SqlTypes.DATETIME) + .addDateTimeField("m_datetime_tz") + .addLogicalTypeField("d_date", SqlTypes.DATE) + .addLogicalTypeField("d_datetime", SqlTypes.DATETIME) + .addDateTimeField("d_datetime_tz") + .addLogicalTypeField("h_datetime", SqlTypes.DATETIME) + .addDateTimeField("h_datetime_tz") + .build(); + org.apache.iceberg.Schema icebergSchema = IcebergUtils.beamSchemaToIcebergSchema(schema); + PartitionSpec spec = + PartitionSpec.builderFor(icebergSchema) + .identity("str") + .bucket("int", 5) + .year("y_date") + .year("y_datetime") + .year("y_datetime_tz") + .month("m_date") + .month("m_datetime") + .month("m_datetime_tz") + .day("d_date") + .day("d_datetime") + .day("d_datetime_tz") + .hour("h_datetime") + .hour("h_datetime_tz") + .build(); + String identifier = "default.table_" + Long.toString(UUID.randomUUID().hashCode(), 16); + + warehouse.createTable(TableIdentifier.parse(identifier), icebergSchema, spec); + Map config = + ImmutableMap.of( + "table", + identifier, + "catalog_properties", + ImmutableMap.of("type", "hadoop", "warehouse", warehouse.location)); + + List rows = new ArrayList<>(); + for (int i = 0; i < 30; i++) { + long millis = i * 100_00_000_000L; + LocalDate localDate = DateTimeUtil.dateFromDays(i * 100); + LocalDateTime localDateTime = DateTimeUtil.timestampFromMicros(millis * 1000); + DateTime dateTime = new DateTime(millis).withZone(DateTimeZone.forOffsetHoursMinutes(3, 25)); + Row row = + Row.withSchema(schema) + .addValues( + "str_" + i, + i, + localDate, + localDateTime, + dateTime, + localDate, + localDateTime, + dateTime, + localDate, + localDateTime, + dateTime, + localDateTime, + 
dateTime) + .build(); + rows.add(row); + } + + PCollection result = + testPipeline + .apply("Records To Add", Create.of(rows)) + .setRowSchema(schema) + .apply(Managed.write(Managed.ICEBERG).withConfig(config)) + .get(SNAPSHOTS_TAG); + + PAssert.that(result) + .satisfies(new VerifyOutputs(Collections.singletonList(identifier), "append")); + testPipeline.run().waitUntilFinish(); + + Pipeline p = Pipeline.create(TestPipeline.testingPipelineOptions()); + PCollection readRows = + p.apply(Managed.read(Managed.ICEBERG).withConfig(config)).getSinglePCollection(); + PAssert.that(readRows).containsInAnyOrder(rows); + p.run(); + } } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java index 2bce390e0992..5168f71fef99 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java @@ -27,9 +27,14 @@ import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; @@ -39,6 +44,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionField; import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; @@ -46,6 +52,8 @@ import org.apache.iceberg.catalog.TableIdentifier; import 
org.apache.iceberg.hadoop.HadoopCatalog; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; import org.junit.Before; import org.junit.ClassRule; import org.junit.Rule; @@ -85,9 +93,14 @@ public void setUp() { private WindowedValue getWindowedDestination( String tableName, @Nullable PartitionSpec partitionSpec) { + return getWindowedDestination(tableName, ICEBERG_SCHEMA, partitionSpec); + } + + private WindowedValue getWindowedDestination( + String tableName, org.apache.iceberg.Schema schema, @Nullable PartitionSpec partitionSpec) { TableIdentifier tableIdentifier = TableIdentifier.of("default", tableName); - warehouse.createTable(tableIdentifier, ICEBERG_SCHEMA, partitionSpec); + warehouse.createTable(tableIdentifier, schema, partitionSpec); IcebergDestination icebergDestination = IcebergDestination.builder() @@ -314,8 +327,15 @@ public void testSerializableDataFileRoundTripEquality() throws IOException { DataFile datafile = writer.getDataFile(); assertEquals(2L, datafile.recordCount()); + Map partitionFieldMap = new HashMap<>(); + for (PartitionField partitionField : PARTITION_SPEC.fields()) { + partitionFieldMap.put(partitionField.name(), partitionField); + } + + String partitionPath = + RecordWriterManager.getPartitionDataPath(partitionKey.toPath(), partitionFieldMap); DataFile roundTripDataFile = - SerializableDataFile.from(datafile, partitionKey) + SerializableDataFile.from(datafile, partitionPath) .createDataFile(ImmutableMap.of(PARTITION_SPEC.specId(), PARTITION_SPEC)); checkDataFileEquality(datafile, roundTripDataFile); @@ -347,8 +367,14 @@ public void testRecreateSerializableDataAfterUpdatingPartitionSpec() throws IOEx writer.close(); // fetch data file and its serializable version + Map partitionFieldMap = new HashMap<>(); + for (PartitionField partitionField : PARTITION_SPEC.fields()) { + partitionFieldMap.put(partitionField.name(), partitionField); + } + String partitionPath = 
+ RecordWriterManager.getPartitionDataPath(partitionKey.toPath(), partitionFieldMap); DataFile datafile = writer.getDataFile(); - SerializableDataFile serializableDataFile = SerializableDataFile.from(datafile, partitionKey); + SerializableDataFile serializableDataFile = SerializableDataFile.from(datafile, partitionPath); assertEquals(2L, datafile.recordCount()); assertEquals(serializableDataFile.getPartitionSpecId(), datafile.specId()); @@ -415,6 +441,198 @@ public void testWriterKeepsUpWithUpdatingPartitionSpec() throws IOException { } } + @Test + public void testIdentityPartitioning() throws IOException { + Schema primitiveTypeSchema = + Schema.builder() + .addBooleanField("bool") + .addInt32Field("int") + .addInt64Field("long") + .addFloatField("float") + .addDoubleField("double") + .addStringField("str") + .build(); + + Row row = + Row.withSchema(primitiveTypeSchema).addValues(true, 1, 1L, 1.23f, 4.56, "str").build(); + org.apache.iceberg.Schema icebergSchema = + IcebergUtils.beamSchemaToIcebergSchema(primitiveTypeSchema); + PartitionSpec spec = + PartitionSpec.builderFor(icebergSchema) + .identity("bool") + .identity("int") + .identity("long") + .identity("float") + .identity("double") + .identity("str") + .build(); + WindowedValue dest = + getWindowedDestination("identity_partitioning", icebergSchema, spec); + + RecordWriterManager writer = + new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE); + writer.write(dest, row); + writer.close(); + List files = writer.getSerializableDataFiles().get(dest); + assertEquals(1, files.size()); + SerializableDataFile dataFile = files.get(0); + assertEquals(1, dataFile.getRecordCount()); + // build this string: bool=true/int=1/long=1/float=1.0/double=1.0/str=str + List expectedPartitions = new ArrayList<>(); + for (Schema.Field field : primitiveTypeSchema.getFields()) { + Object val = row.getValue(field.getName()); + expectedPartitions.add(field.getName() + "=" + val); + } + String 
expectedPartitionPath = String.join("/", expectedPartitions); + assertEquals(expectedPartitionPath, dataFile.getPartitionPath()); + assertThat(dataFile.getPath(), containsString(expectedPartitionPath)); + } + + @Test + public void testBucketPartitioning() throws IOException { + Schema bucketSchema = + Schema.builder() + .addInt32Field("int") + .addInt64Field("long") + .addStringField("str") + .addLogicalTypeField("date", SqlTypes.DATE) + .addLogicalTypeField("time", SqlTypes.TIME) + .addLogicalTypeField("datetime", SqlTypes.DATETIME) + .addDateTimeField("datetime_tz") + .build(); + + String timestamp = "2024-10-08T13:18:20.053"; + LocalDateTime localDateTime = LocalDateTime.parse(timestamp); + + Row row = + Row.withSchema(bucketSchema) + .addValues( + 1, + 1L, + "str", + localDateTime.toLocalDate(), + localDateTime.toLocalTime(), + localDateTime, + DateTime.parse(timestamp)) + .build(); + org.apache.iceberg.Schema icebergSchema = IcebergUtils.beamSchemaToIcebergSchema(bucketSchema); + PartitionSpec spec = + PartitionSpec.builderFor(icebergSchema) + .bucket("int", 2) + .bucket("long", 2) + .bucket("str", 2) + .bucket("date", 2) + .bucket("time", 2) + .bucket("datetime", 2) + .bucket("datetime_tz", 2) + .build(); + WindowedValue dest = + getWindowedDestination("bucket_partitioning", icebergSchema, spec); + + RecordWriterManager writer = + new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE); + writer.write(dest, row); + writer.close(); + List files = writer.getSerializableDataFiles().get(dest); + assertEquals(1, files.size()); + SerializableDataFile dataFile = files.get(0); + assertEquals(1, dataFile.getRecordCount()); + for (Schema.Field field : bucketSchema.getFields()) { + String expectedPartition = field.getName() + "_bucket"; + assertThat(dataFile.getPartitionPath(), containsString(expectedPartition)); + assertThat(dataFile.getPath(), containsString(expectedPartition)); + } + } + + @Test + public void testTimePartitioning() throws 
IOException { + Schema timePartitioningSchema = + Schema.builder() + .addLogicalTypeField("y_date", SqlTypes.DATE) + .addLogicalTypeField("y_datetime", SqlTypes.DATETIME) + .addDateTimeField("y_datetime_tz") + .addLogicalTypeField("m_date", SqlTypes.DATE) + .addLogicalTypeField("m_datetime", SqlTypes.DATETIME) + .addDateTimeField("m_datetime_tz") + .addLogicalTypeField("d_date", SqlTypes.DATE) + .addLogicalTypeField("d_datetime", SqlTypes.DATETIME) + .addDateTimeField("d_datetime_tz") + .addLogicalTypeField("h_datetime", SqlTypes.DATETIME) + .addDateTimeField("h_datetime_tz") + .build(); + org.apache.iceberg.Schema icebergSchema = + IcebergUtils.beamSchemaToIcebergSchema(timePartitioningSchema); + PartitionSpec spec = + PartitionSpec.builderFor(icebergSchema) + .year("y_date") + .year("y_datetime") + .year("y_datetime_tz") + .month("m_date") + .month("m_datetime") + .month("m_datetime_tz") + .day("d_date") + .day("d_datetime") + .day("d_datetime_tz") + .hour("h_datetime") + .hour("h_datetime_tz") + .build(); + + WindowedValue dest = + getWindowedDestination("time_partitioning", icebergSchema, spec); + + String timestamp = "2024-10-08T13:18:20.053"; + LocalDateTime localDateTime = LocalDateTime.parse(timestamp); + LocalDate localDate = localDateTime.toLocalDate(); + String timestamptz = "2024-10-08T13:18:20.053+03:27"; + DateTime dateTime = DateTime.parse(timestamptz); + + Row row = + Row.withSchema(timePartitioningSchema) + .addValues(localDate, localDateTime, dateTime) // year + .addValues(localDate, localDateTime, dateTime) // month + .addValues(localDate, localDateTime, dateTime) // day + .addValues(localDateTime, dateTime) // hour + .build(); + + // write some rows + RecordWriterManager writer = + new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE); + writer.write(dest, row); + writer.close(); + List files = writer.getSerializableDataFiles().get(dest); + assertEquals(1, files.size()); + SerializableDataFile serializableDataFile = 
files.get(0); + assertEquals(1, serializableDataFile.getRecordCount()); + + int year = localDateTime.getYear(); + int month = localDateTime.getMonthValue(); + int day = localDateTime.getDayOfMonth(); + int hour = localDateTime.getHour(); + List expectedPartitions = new ArrayList<>(); + for (Schema.Field field : timePartitioningSchema.getFields()) { + String name = field.getName(); + String expected = ""; + if (name.startsWith("y_")) { + expected = String.format("%s_year=%s", name, year); + } else if (name.startsWith("m_")) { + expected = String.format("%s_month=%s-%02d", name, year, month); + } else if (name.startsWith("d_")) { + expected = String.format("%s_day=%s-%02d-%02d", name, year, month, day); + } else if (name.startsWith("h_")) { + if (name.contains("tz")) { + hour = dateTime.withZone(DateTimeZone.UTC).getHourOfDay(); + } + expected = String.format("%s_hour=%s-%02d-%02d-%02d", name, year, month, day, hour); + } + expectedPartitions.add(expected); + } + String expectedPartition = String.join("/", expectedPartitions); + DataFile dataFile = + serializableDataFile.createDataFile( + catalog.loadTable(dest.getValue().getTableIdentifier()).specs()); + assertThat(dataFile.path().toString(), containsString(expectedPartition)); + } + @Rule public ExpectedException thrown = ExpectedException.none(); @Test diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java index 1e1c84d31de9..9352123b5c77 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java @@ -32,6 +32,7 @@ import org.apache.iceberg.FileFormat; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; 
import org.apache.iceberg.catalog.Namespace; @@ -108,6 +109,16 @@ protected void after() { public DataFile writeRecords(String filename, Schema schema, List records) throws IOException { + return writeRecords(filename, schema, PartitionSpec.unpartitioned(), null, records); + } + + public DataFile writeRecords( + String filename, + Schema schema, + PartitionSpec spec, + StructLike partition, + List records) + throws IOException { Path path = new Path(location, filename); FileFormat format = FileFormat.fromFileName(filename); @@ -134,9 +145,11 @@ public DataFile writeRecords(String filename, Schema schema, List record } appender.addAll(records); appender.close(); - return DataFiles.builder(PartitionSpec.unpartitioned()) + + return DataFiles.builder(spec) .withInputFile(HadoopInputFile.fromPath(path, hadoopConf)) .withMetrics(appender.metrics()) + .withPartition(partition) .build(); } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java index 6143bd03491d..a2ca86d1b5a2 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java @@ -21,11 +21,13 @@ import static org.apache.iceberg.types.Types.NestedField.required; import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.iceberg.Schema; -import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Types; @@ -34,58 +36,75 @@ public class TestFixtures { new Schema( required(1, "id", 
Types.LongType.get()), optional(2, "data", Types.StringType.get())); - private static final Record genericRecord = GenericRecord.create(SCHEMA); - - /* First file in test table */ - public static final ImmutableList FILE1SNAPSHOT1 = + public static final List> FILE1SNAPSHOT1_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 0L, "data", "clarification")), - genericRecord.copy(ImmutableMap.of("id", 1L, "data", "risky")), - genericRecord.copy(ImmutableMap.of("id", 2L, "data", "falafel"))); - public static final ImmutableList FILE1SNAPSHOT2 = + ImmutableMap.of("id", 0L, "data", "clarification"), + ImmutableMap.of("id", 1L, "data", "risky"), + ImmutableMap.of("id", 2L, "data", "falafel")); + public static final List> FILE1SNAPSHOT2_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 3L, "data", "obscure")), - genericRecord.copy(ImmutableMap.of("id", 4L, "data", "secure")), - genericRecord.copy(ImmutableMap.of("id", 5L, "data", "feta"))); - public static final ImmutableList FILE1SNAPSHOT3 = + ImmutableMap.of("id", 3L, "data", "obscure"), + ImmutableMap.of("id", 4L, "data", "secure"), + ImmutableMap.of("id", 5L, "data", "feta")); + public static final List> FILE1SNAPSHOT3_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 6L, "data", "brainy")), - genericRecord.copy(ImmutableMap.of("id", 7L, "data", "film")), - genericRecord.copy(ImmutableMap.of("id", 8L, "data", "feta"))); - - /* Second file in test table */ - public static final ImmutableList FILE2SNAPSHOT1 = + ImmutableMap.of("id", 6L, "data", "brainy"), + ImmutableMap.of("id", 7L, "data", "film"), + ImmutableMap.of("id", 8L, "data", "feta")); + public static final List> FILE2SNAPSHOT1_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 10L, "data", "clammy")), - genericRecord.copy(ImmutableMap.of("id", 11L, "data", "evacuate")), - genericRecord.copy(ImmutableMap.of("id", 12L, "data", "tissue"))); - public static final ImmutableList FILE2SNAPSHOT2 = + 
ImmutableMap.of("id", 10L, "data", "clammy"), + ImmutableMap.of("id", 11L, "data", "evacuate"), + ImmutableMap.of("id", 12L, "data", "tissue")); + public static final List> FILE2SNAPSHOT2_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 14L, "data", "radical")), - genericRecord.copy(ImmutableMap.of("id", 15L, "data", "collocation")), - genericRecord.copy(ImmutableMap.of("id", 16L, "data", "book"))); - public static final ImmutableList FILE2SNAPSHOT3 = + ImmutableMap.of("id", 14L, "data", "radical"), + ImmutableMap.of("id", 15L, "data", "collocation"), + ImmutableMap.of("id", 16L, "data", "book")); + public static final List> FILE2SNAPSHOT3_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 16L, "data", "cake")), - genericRecord.copy(ImmutableMap.of("id", 17L, "data", "intrinsic")), - genericRecord.copy(ImmutableMap.of("id", 18L, "data", "paper"))); - - /* Third file in test table */ - public static final ImmutableList FILE3SNAPSHOT1 = + ImmutableMap.of("id", 16L, "data", "cake"), + ImmutableMap.of("id", 17L, "data", "intrinsic"), + ImmutableMap.of("id", 18L, "data", "paper")); + public static final List> FILE3SNAPSHOT1_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 20L, "data", "ocean")), - genericRecord.copy(ImmutableMap.of("id", 21L, "data", "holistic")), - genericRecord.copy(ImmutableMap.of("id", 22L, "data", "preventative"))); - public static final ImmutableList FILE3SNAPSHOT2 = + ImmutableMap.of("id", 20L, "data", "ocean"), + ImmutableMap.of("id", 21L, "data", "holistic"), + ImmutableMap.of("id", 22L, "data", "preventative")); + public static final List> FILE3SNAPSHOT2_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 24L, "data", "cloud")), - genericRecord.copy(ImmutableMap.of("id", 25L, "data", "zen")), - genericRecord.copy(ImmutableMap.of("id", 26L, "data", "sky"))); - public static final ImmutableList FILE3SNAPSHOT3 = + ImmutableMap.of("id", 24L, "data", "cloud"), + 
ImmutableMap.of("id", 25L, "data", "zen"), + ImmutableMap.of("id", 26L, "data", "sky")); + public static final List> FILE3SNAPSHOT3_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 26L, "data", "belleview")), - genericRecord.copy(ImmutableMap.of("id", 27L, "data", "overview")), - genericRecord.copy(ImmutableMap.of("id", 28L, "data", "tender"))); + ImmutableMap.of("id", 26L, "data", "belleview"), + ImmutableMap.of("id", 27L, "data", "overview"), + ImmutableMap.of("id", 28L, "data", "tender")); + + /* First file in test table */ + public static final List FILE1SNAPSHOT1 = + Lists.transform(FILE1SNAPSHOT1_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE1SNAPSHOT2 = + Lists.transform(FILE1SNAPSHOT2_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE1SNAPSHOT3 = + Lists.transform(FILE1SNAPSHOT3_DATA, d -> createRecord(SCHEMA, d)); + + /* Second file in test table */ + public static final List FILE2SNAPSHOT1 = + Lists.transform(FILE2SNAPSHOT1_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE2SNAPSHOT2 = + Lists.transform(FILE2SNAPSHOT2_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE2SNAPSHOT3 = + Lists.transform(FILE2SNAPSHOT3_DATA, d -> createRecord(SCHEMA, d)); + + /* Third file in test table */ + public static final List FILE3SNAPSHOT1 = + Lists.transform(FILE3SNAPSHOT1_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE3SNAPSHOT2 = + Lists.transform(FILE3SNAPSHOT2_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE3SNAPSHOT3 = + Lists.transform(FILE3SNAPSHOT3_DATA, d -> createRecord(SCHEMA, d)); public static final ImmutableList asRows(Iterable records) { ArrayList rows = new ArrayList<>(); @@ -98,4 +117,8 @@ public static final ImmutableList asRows(Iterable records) { } return ImmutableList.copyOf(rows); } + + public static Record createRecord(org.apache.iceberg.Schema schema, Map values) { + return 
org.apache.iceberg.data.GenericRecord.create(schema).copy(values); + } } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/HadoopCatalogIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/HadoopCatalogIT.java new file mode 100644 index 000000000000..d33a372e5e3b --- /dev/null +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/HadoopCatalogIT.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.iceberg.catalog; + +import java.io.IOException; +import java.util.Map; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.hadoop.HadoopCatalog; + +public class HadoopCatalogIT extends IcebergCatalogBaseIT { + @Override + public Integer numRecords() { + return 100; + } + + @Override + public Catalog createCatalog() { + Configuration catalogHadoopConf = new Configuration(); + catalogHadoopConf.set("fs.gs.project.id", options.getProject()); + catalogHadoopConf.set("fs.gs.auth.type", "APPLICATION_DEFAULT"); + + HadoopCatalog catalog = new HadoopCatalog(); + catalog.setConf(catalogHadoopConf); + catalog.initialize("hadoop_" + catalogName, ImmutableMap.of("warehouse", warehouse)); + + return catalog; + } + + @Override + public void catalogCleanup() throws IOException { + ((HadoopCatalog) catalog).close(); + } + + @Override + public Map managedIcebergConfig(String tableId) { + return ImmutableMap.builder() + .put("table", tableId) + .put( + "catalog_properties", + ImmutableMap.builder() + .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) + .put("warehouse", warehouse) + .build()) + .build(); + } +} diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/HiveCatalogIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/HiveCatalogIT.java new file mode 100644 index 000000000000..f31eb19906ff --- /dev/null +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/HiveCatalogIT.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.iceberg.catalog; + +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.beam.sdk.io.iceberg.catalog.hiveutils.HiveMetastoreExtension; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.hive.HiveCatalog; + +/** + * Read and write tests using {@link HiveCatalog}. + * + *

Spins up a local Hive metastore to manage the Iceberg table. Warehouse path is set to a GCS + * bucket. + */ +public class HiveCatalogIT extends IcebergCatalogBaseIT { + private static HiveMetastoreExtension hiveMetastoreExtension; + private static final String TEST_DB = "test_db"; + + @Override + public String tableId() { + return String.format("%s.%s", TEST_DB, testName.getMethodName()); + } + + @Override + public void catalogSetup() throws Exception { + hiveMetastoreExtension = new HiveMetastoreExtension(warehouse); + String dbPath = hiveMetastoreExtension.metastore().getDatabasePath(TEST_DB); + Database db = new Database(TEST_DB, "description", dbPath, Maps.newHashMap()); + hiveMetastoreExtension.metastoreClient().createDatabase(db); + } + + @Override + public Catalog createCatalog() { + return CatalogUtil.loadCatalog( + HiveCatalog.class.getName(), + "hive_" + catalogName, + ImmutableMap.of( + CatalogProperties.CLIENT_POOL_CACHE_EVICTION_INTERVAL_MS, + String.valueOf(TimeUnit.SECONDS.toMillis(10))), + hiveMetastoreExtension.hiveConf()); + } + + @Override + public void catalogCleanup() throws Exception { + System.out.println("xxx CLEANING UP!"); + if (hiveMetastoreExtension != null) { + hiveMetastoreExtension.cleanup(); + } + } + + @Override + public Map managedIcebergConfig(String tableId) { + String metastoreUri = hiveMetastoreExtension.hiveConf().getVar(HiveConf.ConfVars.METASTOREURIS); + + Map confProperties = + ImmutableMap.builder() + .put(HiveConf.ConfVars.METASTOREURIS.varname, metastoreUri) + .build(); + + return ImmutableMap.builder() + .put("table", tableId) + .put("name", "hive_" + catalogName) + .put("config_properties", confProperties) + .build(); + } +} diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/IcebergCatalogBaseIT.java similarity index 77% rename from 
sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java rename to sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/IcebergCatalogBaseIT.java index c79b0a550051..8e4a74cd61d4 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/IcebergCatalogBaseIT.java @@ -15,9 +15,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.iceberg; +package org.apache.beam.sdk.io.iceberg.catalog; -import static org.apache.beam.sdk.schemas.Schema.FieldType; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -40,7 +39,9 @@ import org.apache.beam.sdk.extensions.gcp.options.GcsOptions; import org.apache.beam.sdk.extensions.gcp.util.GcsUtil; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; +import org.apache.beam.sdk.io.iceberg.IcebergUtils; import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -56,14 +57,10 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.AppendFiles; -import org.apache.iceberg.CatalogUtil; import org.apache.iceberg.CombinedScanTask; import org.apache.iceberg.FileScanTask; import org.apache.iceberg.PartitionSpec; -import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import 
org.apache.iceberg.TableScan; import org.apache.iceberg.catalog.Catalog; @@ -72,7 +69,6 @@ import org.apache.iceberg.data.parquet.GenericParquetReaders; import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.encryption.InputFilesDecryptor; -import org.apache.iceberg.hadoop.HadoopCatalog; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.DataWriter; import org.apache.iceberg.io.InputFile; @@ -84,44 +80,123 @@ import org.joda.time.DateTimeZone; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.AfterClass; +import org.junit.After; import org.junit.Before; -import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Integration tests for {@link IcebergIO} source and sink. */ -@RunWith(JUnit4.class) -public class IcebergIOIT implements Serializable { - private static final Logger LOG = LoggerFactory.getLogger(IcebergIOIT.class); +/** + * Base class for {@link Managed} {@link org.apache.beam.sdk.io.iceberg.IcebergIO} read and write + * tests. + * + *

To test a new catalog, create a subclass of this test class and implement the following two + * methods: + * + *

    + *
  • {@link #createCatalog()} + *
  • {@link #managedIcebergConfig(String)} + *
+ * + *

If the catalog needs further logic to set up and tear down, you can override and implement + * these methods: + * + *

    + *
  • {@link #catalogSetup()} + *
  • {@link #catalogCleanup()} + *
+ * + *

1,000 records are used for each test by default. You can change this by overriding {@link + * #numRecords()}. + */ +public abstract class IcebergCatalogBaseIT implements Serializable { + public abstract Catalog createCatalog(); + + public abstract Map managedIcebergConfig(String tableId); + + public void catalogSetup() throws Exception {} + + public void catalogCleanup() throws Exception {} + + public Integer numRecords() { + return 1000; + } + + public String tableId() { + return testName.getMethodName() + ".test_table"; + } + + public String catalogName = "test_catalog_" + System.nanoTime(); + + @Before + public void setUp() throws Exception { + options = TestPipeline.testingPipelineOptions().as(GcpOptions.class); + warehouse = + String.format( + "%s/%s/%s", + TestPipeline.testingPipelineOptions().getTempLocation(), + getClass().getSimpleName(), + RANDOM); + catalogSetup(); + catalog = createCatalog(); + } - private static final org.apache.beam.sdk.schemas.Schema DOUBLY_NESTED_ROW_SCHEMA = - org.apache.beam.sdk.schemas.Schema.builder() + @After + public void cleanUp() throws Exception { + catalogCleanup(); + + try { + GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); + GcsPath path = GcsPath.fromUri(warehouse); + + Objects objects = + gcsUtil.listObjects( + path.getBucket(), + getClass().getSimpleName() + "/" + path.getFileName().toString(), + null); + List filesToDelete = + objects.getItems().stream() + .map(obj -> "gs://" + path.getBucket() + "/" + obj.getName()) + .collect(Collectors.toList()); + + gcsUtil.remove(filesToDelete); + } catch (Exception e) { + LOG.warn("Failed to clean up files.", e); + } + } + + protected static String warehouse; + public Catalog catalog; + protected GcpOptions options; + private static final String RANDOM = UUID.randomUUID().toString(); + @Rule public TestPipeline pipeline = TestPipeline.create(); + @Rule public TestName testName = new TestName(); + private static final int NUM_SHARDS = 10; + private static final 
Logger LOG = LoggerFactory.getLogger(IcebergCatalogBaseIT.class); + private static final Schema DOUBLY_NESTED_ROW_SCHEMA = + Schema.builder() .addStringField("doubly_nested_str") .addInt64Field("doubly_nested_float") .build(); - private static final org.apache.beam.sdk.schemas.Schema NESTED_ROW_SCHEMA = - org.apache.beam.sdk.schemas.Schema.builder() + private static final Schema NESTED_ROW_SCHEMA = + Schema.builder() .addStringField("nested_str") .addRowField("nested_row", DOUBLY_NESTED_ROW_SCHEMA) .addInt32Field("nested_int") .addFloatField("nested_float") .build(); - private static final org.apache.beam.sdk.schemas.Schema BEAM_SCHEMA = - org.apache.beam.sdk.schemas.Schema.builder() + private static final Schema BEAM_SCHEMA = + Schema.builder() .addStringField("str") .addStringField("char") .addInt64Field("modulo_5") .addBooleanField("bool") .addInt32Field("int") .addRowField("row", NESTED_ROW_SCHEMA) - .addArrayField("arr_long", FieldType.INT64) + .addArrayField("arr_long", Schema.FieldType.INT64) .addNullableRowField("nullable_row", NESTED_ROW_SCHEMA) .addNullableInt64Field("nullable_long") .addDateTimeField("datetime_tz") @@ -174,65 +249,16 @@ public Record apply(Row input) { return IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, input); } }; - private static final Integer NUM_RECORDS = 1000; - private static final Integer NUM_SHARDS = 10; - - @Rule public TestPipeline pipeline = TestPipeline.create(); - - static GcpOptions options; - - static Configuration catalogHadoopConf; - - @Rule public TestName testName = new TestName(); - - private static String warehouseLocation; - - private String tableId; - private static Catalog catalog; - - @BeforeClass - public static void beforeClass() { - options = TestPipeline.testingPipelineOptions().as(GcpOptions.class); - warehouseLocation = - String.format("%s/IcebergIOIT/%s", options.getTempLocation(), UUID.randomUUID()); - catalogHadoopConf = new Configuration(); - catalogHadoopConf.set("fs.gs.project.id", 
options.getProject()); - catalogHadoopConf.set("fs.gs.auth.type", "APPLICATION_DEFAULT"); - catalog = new HadoopCatalog(catalogHadoopConf, warehouseLocation); - } - - @Before - public void setUp() { - tableId = testName.getMethodName() + ".test_table"; - } - - @AfterClass - public static void afterClass() { - try { - GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); - GcsPath path = GcsPath.fromUri(warehouseLocation); - - Objects objects = - gcsUtil.listObjects( - path.getBucket(), "IcebergIOIT/" + path.getFileName().toString(), null); - List filesToDelete = - objects.getItems().stream() - .map(obj -> "gs://" + path.getBucket() + "/" + obj.getName()) - .collect(Collectors.toList()); - - gcsUtil.remove(filesToDelete); - } catch (Exception e) { - LOG.warn("Failed to clean up files.", e); - } - } + private final List inputRows = + LongStream.range(0, numRecords()).boxed().map(ROW_FUNC::apply).collect(Collectors.toList()); /** Populates the Iceberg table and Returns a {@link List} of expected elements. 
*/ private List populateTable(Table table) throws IOException { - double recordsPerShardFraction = NUM_RECORDS.doubleValue() / NUM_SHARDS; + double recordsPerShardFraction = numRecords().doubleValue() / NUM_SHARDS; long maxRecordsPerShard = Math.round(Math.ceil(recordsPerShardFraction)); AppendFiles appendFiles = table.newAppend(); - List expectedRows = new ArrayList<>(NUM_RECORDS); + List expectedRows = new ArrayList<>(numRecords()); int totalRecords = 0; for (int shardNum = 0; shardNum < NUM_SHARDS; ++shardNum) { String filepath = table.location() + "/" + UUID.randomUUID(); @@ -246,7 +272,7 @@ private List populateTable(Table table) throws IOException { .build(); for (int recordNum = 0; - recordNum < maxRecordsPerShard && totalRecords < NUM_RECORDS; + recordNum < maxRecordsPerShard && totalRecords < numRecords(); ++recordNum, ++totalRecords) { Row expectedBeamRow = ROW_FUNC.apply((long) recordNum); @@ -264,7 +290,7 @@ private List populateTable(Table table) throws IOException { } private List readRecords(Table table) { - Schema tableSchema = table.schema(); + org.apache.iceberg.Schema tableSchema = table.schema(); TableScan tableScan = table.newScan().project(tableSchema); List writtenRecords = new ArrayList<>(); for (CombinedScanTask task : tableScan.planTasks()) { @@ -289,31 +315,13 @@ private List readRecords(Table table) { return writtenRecords; } - private Map managedIcebergConfig(String tableId) { - return ImmutableMap.builder() - .put("table", tableId) - .put("catalog_name", "test-name") - .put( - "catalog_properties", - ImmutableMap.builder() - .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .put("warehouse", warehouseLocation) - .build()) - .build(); - } - - /** - * Test of a predetermined moderate number of records written directly to Iceberg then read via a - * Beam pipeline. Table initialization is done on a single process using the Iceberg APIs so the - * data cannot be "big". 
- */ @Test public void testRead() throws Exception { - Table table = catalog.createTable(TableIdentifier.parse(tableId), ICEBERG_SCHEMA); + Table table = catalog.createTable(TableIdentifier.parse(tableId()), ICEBERG_SCHEMA); List expectedRows = populateTable(table); - Map config = managedIcebergConfig(tableId); + Map config = managedIcebergConfig(tableId()); PCollection rows = pipeline.apply(Managed.read(Managed.ICEBERG).withConfig(config)).getSinglePCollection(); @@ -322,70 +330,64 @@ public void testRead() throws Exception { pipeline.run().waitUntilFinish(); } - private static final List INPUT_ROWS = - LongStream.range(0, NUM_RECORDS).boxed().map(ROW_FUNC::apply).collect(Collectors.toList()); - - /** - * Test of a predetermined moderate number of records written to Iceberg using a Beam pipeline, - * then read directly using Iceberg API. - */ @Test public void testWrite() { // Write with Beam // Expect the sink to create the table - Map config = managedIcebergConfig(tableId); - PCollection input = pipeline.apply(Create.of(INPUT_ROWS)).setRowSchema(BEAM_SCHEMA); + Map config = managedIcebergConfig(tableId()); + PCollection input = pipeline.apply(Create.of(inputRows)).setRowSchema(BEAM_SCHEMA); input.apply(Managed.write(Managed.ICEBERG).withConfig(config)); pipeline.run().waitUntilFinish(); - Table table = catalog.loadTable(TableIdentifier.parse(tableId)); + Table table = catalog.loadTable(TableIdentifier.parse(tableId())); assertTrue(table.schema().sameSchema(ICEBERG_SCHEMA)); // Read back and check records are correct List returnedRecords = readRecords(table); assertThat( - returnedRecords, containsInAnyOrder(INPUT_ROWS.stream().map(RECORD_FUNC::apply).toArray())); + returnedRecords, containsInAnyOrder(inputRows.stream().map(RECORD_FUNC::apply).toArray())); } @Test - public void testWritePartitionedData() { + public void testWriteToPartitionedTable() { // For an example row where bool=true, modulo_5=3, str=value_303, // this partition spec will create a partition 
like: /bool=true/modulo_5=3/str_trunc=value_3/ PartitionSpec partitionSpec = PartitionSpec.builderFor(ICEBERG_SCHEMA) .identity("bool") - .identity("modulo_5") + .hour("datetime") .truncate("str", "value_x".length()) .build(); Table table = - catalog.createTable(TableIdentifier.parse(tableId), ICEBERG_SCHEMA, partitionSpec); + catalog.createTable(TableIdentifier.parse(tableId()), ICEBERG_SCHEMA, partitionSpec); // Write with Beam - Map config = managedIcebergConfig(tableId); - PCollection input = pipeline.apply(Create.of(INPUT_ROWS)).setRowSchema(BEAM_SCHEMA); + Map config = managedIcebergConfig(tableId()); + PCollection input = pipeline.apply(Create.of(inputRows)).setRowSchema(BEAM_SCHEMA); input.apply(Managed.write(Managed.ICEBERG).withConfig(config)); pipeline.run().waitUntilFinish(); // Read back and check records are correct List returnedRecords = readRecords(table); assertThat( - returnedRecords, containsInAnyOrder(INPUT_ROWS.stream().map(RECORD_FUNC::apply).toArray())); + returnedRecords, containsInAnyOrder(inputRows.stream().map(RECORD_FUNC::apply).toArray())); } private PeriodicImpulse getStreamingSource() { return PeriodicImpulse.create() - .stopAfter(Duration.millis(NUM_RECORDS - 1)) + .stopAfter(Duration.millis(numRecords() - 1)) .withInterval(Duration.millis(1)); } @Test public void testStreamingWrite() { + int numRecords = numRecords(); PartitionSpec partitionSpec = PartitionSpec.builderFor(ICEBERG_SCHEMA).identity("bool").identity("modulo_5").build(); Table table = - catalog.createTable(TableIdentifier.parse(tableId), ICEBERG_SCHEMA, partitionSpec); + catalog.createTable(TableIdentifier.parse(tableId()), ICEBERG_SCHEMA, partitionSpec); - Map config = new HashMap<>(managedIcebergConfig(tableId)); + Map config = new HashMap<>(managedIcebergConfig(tableId())); config.put("triggering_frequency_seconds", 4); // create elements from longs in range [0, 1000) @@ -394,7 +396,7 @@ public void testStreamingWrite() { .apply(getStreamingSource()) .apply( 
MapElements.into(TypeDescriptors.rows()) - .via(instant -> ROW_FUNC.apply(instant.getMillis() % NUM_RECORDS))) + .via(instant -> ROW_FUNC.apply(instant.getMillis() % numRecords))) .setRowSchema(BEAM_SCHEMA); assertThat(input.isBounded(), equalTo(PCollection.IsBounded.UNBOUNDED)); @@ -404,17 +406,18 @@ public void testStreamingWrite() { List returnedRecords = readRecords(table); assertThat( - returnedRecords, containsInAnyOrder(INPUT_ROWS.stream().map(RECORD_FUNC::apply).toArray())); + returnedRecords, containsInAnyOrder(inputRows.stream().map(RECORD_FUNC::apply).toArray())); } @Test public void testStreamingWriteWithPriorWindowing() { + int numRecords = numRecords(); PartitionSpec partitionSpec = PartitionSpec.builderFor(ICEBERG_SCHEMA).identity("bool").identity("modulo_5").build(); Table table = - catalog.createTable(TableIdentifier.parse(tableId), ICEBERG_SCHEMA, partitionSpec); + catalog.createTable(TableIdentifier.parse(tableId()), ICEBERG_SCHEMA, partitionSpec); - Map config = new HashMap<>(managedIcebergConfig(tableId)); + Map config = new HashMap<>(managedIcebergConfig(tableId())); config.put("triggering_frequency_seconds", 4); // over a span of 10 seconds, create elements from longs in range [0, 1000) @@ -426,7 +429,7 @@ public void testStreamingWriteWithPriorWindowing() { .accumulatingFiredPanes()) .apply( MapElements.into(TypeDescriptors.rows()) - .via(instant -> ROW_FUNC.apply(instant.getMillis() % NUM_RECORDS))) + .via(instant -> ROW_FUNC.apply(instant.getMillis() % numRecords))) .setRowSchema(BEAM_SCHEMA); assertThat(input.isBounded(), equalTo(PCollection.IsBounded.UNBOUNDED)); @@ -436,7 +439,7 @@ public void testStreamingWriteWithPriorWindowing() { List returnedRecords = readRecords(table); assertThat( - returnedRecords, containsInAnyOrder(INPUT_ROWS.stream().map(RECORD_FUNC::apply).toArray())); + returnedRecords, containsInAnyOrder(inputRows.stream().map(RECORD_FUNC::apply).toArray())); } private void writeToDynamicDestinations(@Nullable String 
filterOp) { @@ -450,7 +453,8 @@ private void writeToDynamicDestinations(@Nullable String filterOp) { */ private void writeToDynamicDestinations( @Nullable String filterOp, boolean streaming, boolean partitioning) { - String tableIdentifierTemplate = tableId + "_{modulo_5}_{char}"; + int numRecords = numRecords(); + String tableIdentifierTemplate = tableId() + "_{modulo_5}_{char}"; Map writeConfig = new HashMap<>(managedIcebergConfig(tableIdentifierTemplate)); List fieldsToFilter = Arrays.asList("row", "str", "int", "nullable_long"); @@ -475,13 +479,14 @@ private void writeToDynamicDestinations( } } - Schema tableSchema = IcebergUtils.beamSchemaToIcebergSchema(rowFilter.outputSchema()); + org.apache.iceberg.Schema tableSchema = + IcebergUtils.beamSchemaToIcebergSchema(rowFilter.outputSchema()); - TableIdentifier tableIdentifier0 = TableIdentifier.parse(tableId + "_0_a"); - TableIdentifier tableIdentifier1 = TableIdentifier.parse(tableId + "_1_b"); - TableIdentifier tableIdentifier2 = TableIdentifier.parse(tableId + "_2_c"); - TableIdentifier tableIdentifier3 = TableIdentifier.parse(tableId + "_3_d"); - TableIdentifier tableIdentifier4 = TableIdentifier.parse(tableId + "_4_e"); + TableIdentifier tableIdentifier0 = TableIdentifier.parse(tableId() + "_0_a"); + TableIdentifier tableIdentifier1 = TableIdentifier.parse(tableId() + "_1_b"); + TableIdentifier tableIdentifier2 = TableIdentifier.parse(tableId() + "_2_c"); + TableIdentifier tableIdentifier3 = TableIdentifier.parse(tableId() + "_3_d"); + TableIdentifier tableIdentifier4 = TableIdentifier.parse(tableId() + "_4_e"); // the sink doesn't support creating partitioned tables yet, // so we need to create it manually for this test case if (partitioning) { @@ -504,10 +509,11 @@ private void writeToDynamicDestinations( .apply(getStreamingSource()) .apply( MapElements.into(TypeDescriptors.rows()) - .via(instant -> ROW_FUNC.apply(instant.getMillis() % NUM_RECORDS))); + .via(instant -> ROW_FUNC.apply(instant.getMillis() % 
numRecords))); } else { - input = pipeline.apply(Create.of(INPUT_ROWS)); + input = pipeline.apply(Create.of(inputRows)); } + input.setRowSchema(BEAM_SCHEMA).apply(Managed.write(Managed.ICEBERG).withConfig(writeConfig)); pipeline.run().waitUntilFinish(); @@ -537,7 +543,7 @@ private void writeToDynamicDestinations( List records = returnedRecords.get(i); long l = i; Stream expectedRecords = - INPUT_ROWS.stream() + inputRows.stream() .filter(rec -> checkStateNotNull(rec.getInt64("modulo_5")) == l) .map(rowFilter::filter) .map(recordFunc::apply); @@ -556,11 +562,6 @@ public void testWriteToDynamicDestinationsAndDropFields() { writeToDynamicDestinations("drop"); } - @Test - public void testWriteToDynamicDestinationsAndKeepFields() { - writeToDynamicDestinations("keep"); - } - @Test public void testWriteToDynamicDestinationsWithOnlyRecord() { writeToDynamicDestinations("only"); diff --git a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/HiveMetastoreExtension.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/HiveMetastoreExtension.java similarity index 97% rename from sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/HiveMetastoreExtension.java rename to sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/HiveMetastoreExtension.java index 52de1b91a216..5ed05db27768 100644 --- a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/HiveMetastoreExtension.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/HiveMetastoreExtension.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.io.iceberg.hive.testutils; +package org.apache.beam.sdk.io.iceberg.catalog.hiveutils; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.HiveMetaStoreClient; diff --git a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/ScriptRunner.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/ScriptRunner.java similarity index 99% rename from sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/ScriptRunner.java rename to sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/ScriptRunner.java index adf941e00b4b..d77cf0bf74c7 100644 --- a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/ScriptRunner.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/ScriptRunner.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.io.iceberg.hive.testutils; +package org.apache.beam.sdk.io.iceberg.catalog.hiveutils; import java.io.IOException; import java.io.LineNumberReader; diff --git a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/TestHiveMetastore.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/TestHiveMetastore.java similarity index 91% rename from sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/TestHiveMetastore.java rename to sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/TestHiveMetastore.java index e3af43d58c65..94f519179e9d 100644 --- a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/testutils/TestHiveMetastore.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/catalog/hiveutils/TestHiveMetastore.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.io.iceberg.hive.testutils; +package org.apache.beam.sdk.io.iceberg.catalog.hiveutils; import static java.nio.file.Files.createTempDirectory; import static java.nio.file.attribute.PosixFilePermissions.asFileAttribute; @@ -33,6 +33,7 @@ import java.sql.SQLException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; @@ -63,7 +64,7 @@ public class TestHiveMetastore { private static final String DEFAULT_DATABASE_NAME = "default"; - private static final int DEFAULT_POOL_SIZE = 5; + private static final int DEFAULT_POOL_SIZE = 3; // create the metastore handlers based on whether we're working with Hive2 or Hive3 dependencies // we need to do this because there is a breaking API change between Hive2 and Hive3 @@ -79,18 +80,6 @@ public class TestHiveMetastore { .impl(RetryingHMSHandler.class, HiveConf.class, IHMSHandler.class, boolean.class) .buildStatic(); - // Hive3 introduces background metastore tasks (MetastoreTaskThread) for performing various - // cleanup duties. These - // threads are scheduled and executed in a static thread pool - // (org.apache.hadoop.hive.metastore.ThreadPool). - // This thread pool is shut down normally as part of the JVM shutdown hook, but since we're - // creating and tearing down - // multiple metastore instances within the same JVM, we have to call this cleanup method manually, - // otherwise - // threads from our previous test suite will be stuck in the pool with stale config, and keep on - // being scheduled. - // This can lead to issues, e.g. accidental Persistence Manager closure by - // ScheduledQueryExecutionsMaintTask. 
private static final DynMethods.StaticMethod METASTORE_THREADS_SHUTDOWN = DynMethods.builder("shutdown") .impl("org.apache.hadoop.hive.metastore.ThreadPool") @@ -140,8 +129,7 @@ public class TestHiveMetastore { } /** - * Starts a TestHiveMetastore with the default connection pool size (5) with the provided - * HiveConf. + * Starts a TestHiveMetastore with the default connection pool size with the provided HiveConf. * * @param conf The hive configuration to use */ @@ -181,7 +169,13 @@ public void stop() throws Exception { server.stop(); } if (executorService != null) { - executorService.shutdown(); + executorService.shutdownNow(); + try { + // Give it a reasonable timeout + executorService.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } } if (baseHandler != null) { baseHandler.shutdown(); diff --git a/sdks/java/io/iceberg/hive/src/test/resources/hive-schema-3.1.0.derby.sql b/sdks/java/io/iceberg/src/test/resources/hive-schema-3.1.0.derby.sql similarity index 100% rename from sdks/java/io/iceberg/hive/src/test/resources/hive-schema-3.1.0.derby.sql rename to sdks/java/io/iceberg/src/test/resources/hive-schema-3.1.0.derby.sql diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java index 435bfc138b5b..b4765f0392c1 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import 
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -49,25 +50,131 @@ public class JdbcReadSchemaTransformProvider extends TypedSchemaTransformProvider< JdbcReadSchemaTransformProvider.JdbcReadSchemaTransformConfiguration> { + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:jdbc_read:v1"; + } + + @Override + public String description() { + return "Read from a JDBC source using a SQL query or by directly accessing a single table.\n" + + "\n" + + "This transform can be used to read from a JDBC source using either a given JDBC driver jar " + + "and class name, or by using one of the default packaged drivers given a `jdbc_type`.\n" + + "\n" + + "#### Using a default driver\n" + + "\n" + + "This transform comes packaged with drivers for several popular JDBC distributions. The following " + + "distributions can be declared as the `jdbc_type`: " + + JDBC_DRIVER_MAP.keySet().toString().replaceAll("[\\[\\]]", "") + + ".\n" + + "\n" + + "For example, reading a MySQL source using a SQL query: ::" + + "\n" + + " - type: ReadFromJdbc\n" + + " config:\n" + + " jdbc_type: mysql\n" + + " url: \"jdbc:mysql://my-host:3306/database\"\n" + + " query: \"SELECT * FROM table\"\n" + + "\n" + + "\n" + + "**Note**: See the following transforms which are built on top of this transform and simplify " + + "this logic for several popular JDBC distributions:\n\n" + + " - ReadFromMySql\n" + + " - ReadFromPostgres\n" + + " - ReadFromOracle\n" + + " - ReadFromSqlServer\n" + + "\n" + + "#### Declaring custom JDBC drivers\n" + + "\n" + + "If reading from a JDBC source not listed above, or if it is necessary to use a custom driver not " + + "packaged with Beam, one must define a JDBC driver and class name.\n" + + "\n" + + "For example, reading a MySQL source table: ::" + + "\n" + + " - type: ReadFromJdbc\n" + + " config:\n" + + " 
driver_jars: \"path/to/some/jdbc.jar\"\n" + + " driver_class_name: \"com.mysql.jdbc.Driver\"\n" + + " url: \"jdbc:mysql://my-host:3306/database\"\n" + + " table: \"my-table\"\n" + + "\n" + + "#### Connection Properties\n" + + "\n" + + "Connection properties are properties sent to the Driver used to connect to the JDBC source. For example, " + + "to set the character encoding to UTF-8, one could write: ::\n" + + "\n" + + " - type: ReadFromJdbc\n" + + " config:\n" + + " connectionProperties: \"characterEncoding=UTF-8;\"\n" + + " ...\n" + + "All properties should be semi-colon-delimited (e.g. \"key1=value1;key2=value2;\")\n"; + } + + protected String inheritedDescription( + String prettyName, String transformName, String databaseSchema, int defaultJdbcPort) { + return String.format( + "Read from a %s source using a SQL query or by directly accessing a single table.%n" + + "%n" + + "This is a special case of ReadFromJdbc that includes the " + + "necessary %s Driver and classes.%n" + + "%n" + + "An example of using %s with SQL query: ::%n" + + "%n" + + " - type: %s%n" + + " config:%n" + + " url: \"jdbc:%s://my-host:%d/database\"%n" + + " query: \"SELECT * FROM table\"%n" + + "%n" + + "It is also possible to read a table by specifying a table name. For example, the " + + "following configuration will perform a read on an entire table: ::%n" + + "%n" + + " - type: %s%n" + + " config:%n" + + " url: \"jdbc:%s://my-host:%d/database\"%n" + + " table: \"my-table\"%n" + + "%n" + + "#### Advanced Usage%n" + + "%n" + + "It might be necessary to use a custom JDBC driver that is not packaged with this " + + "transform. 
If that is the case, see ReadFromJdbc which " + + "allows for more custom configuration.", + prettyName, + prettyName, + transformName, + transformName, + databaseSchema, + defaultJdbcPort, + transformName, + databaseSchema, + defaultJdbcPort); + } + @Override protected @UnknownKeyFor @NonNull @Initialized Class configurationClass() { return JdbcReadSchemaTransformConfiguration.class; } + protected String jdbcType() { + return ""; + } + @Override protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( JdbcReadSchemaTransformConfiguration configuration) { - configuration.validate(); - return new JdbcReadSchemaTransform(configuration); + configuration.validate(jdbcType()); + return new JdbcReadSchemaTransform(configuration, jdbcType()); } - static class JdbcReadSchemaTransform extends SchemaTransform implements Serializable { + protected static class JdbcReadSchemaTransform extends SchemaTransform implements Serializable { JdbcReadSchemaTransformConfiguration config; + private final String jdbcType; - public JdbcReadSchemaTransform(JdbcReadSchemaTransformConfiguration config) { + public JdbcReadSchemaTransform(JdbcReadSchemaTransformConfiguration config, String jdbcType) { this.config = config; + this.jdbcType = jdbcType; } protected JdbcIO.DataSourceConfiguration dataSourceConfiguration() { @@ -75,7 +182,10 @@ protected JdbcIO.DataSourceConfiguration dataSourceConfiguration() { if (Strings.isNullOrEmpty(driverClassName)) { driverClassName = - JDBC_DRIVER_MAP.get(Objects.requireNonNull(config.getJdbcType()).toLowerCase()); + JDBC_DRIVER_MAP.get( + (Objects.requireNonNull( + !Strings.isNullOrEmpty(jdbcType) ? 
jdbcType : config.getJdbcType())) + .toLowerCase()); } JdbcIO.DataSourceConfiguration dsConfig = @@ -109,7 +219,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } JdbcIO.ReadRows readRows = JdbcIO.readRows().withDataSourceConfiguration(dataSourceConfiguration()).withQuery(query); - Short fetchSize = config.getFetchSize(); + Integer fetchSize = config.getFetchSize(); if (fetchSize != null && fetchSize > 0) { readRows = readRows.withFetchSize(fetchSize); } @@ -125,11 +235,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } } - @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:jdbc_read:v1"; - } - @Override public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> inputCollectionNames() { @@ -145,62 +250,91 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { @AutoValue @DefaultSchema(AutoValueSchema.class) public abstract static class JdbcReadSchemaTransformConfiguration implements Serializable { - @Nullable - public abstract String getDriverClassName(); + @SchemaFieldDescription("Connection URL for the JDBC source.") + public abstract String getJdbcUrl(); + + @SchemaFieldDescription( + "Sets the connection init sql statements used by the Driver. Only MySQL and MariaDB support this.") @Nullable - public abstract String getJdbcType(); + public abstract List<@org.checkerframework.checker.nullness.qual.Nullable String> + getConnectionInitSql(); - public abstract String getJdbcUrl(); + @SchemaFieldDescription( + "Used to set connection properties passed to the JDBC driver not already defined as standalone parameter (e.g. username and password can be set using parameters above accordingly). Format of the string must be \"key1=value1;key2=value2;\".") + @Nullable + public abstract String getConnectionProperties(); + @SchemaFieldDescription( + "Whether to disable auto commit on read. 
Defaults to true if not provided. The need for this config varies depending on the database platform. Informix requires this to be set to false while Postgres requires this to be set to true.") @Nullable - public abstract String getUsername(); + public abstract Boolean getDisableAutoCommit(); + @SchemaFieldDescription( + "Name of a Java Driver class to use to connect to the JDBC source. For example, \"com.mysql.jdbc.Driver\".") @Nullable - public abstract String getPassword(); + public abstract String getDriverClassName(); + @SchemaFieldDescription( + "Comma separated path(s) for the JDBC driver jar(s). This can be a local path or GCS (gs://) path.") @Nullable - public abstract String getConnectionProperties(); + public abstract String getDriverJars(); + @SchemaFieldDescription( + "This method is used to override the size of the data that is going to be fetched and loaded in memory per every database call. It should ONLY be used if the default value throws memory errors.") @Nullable - public abstract List<@org.checkerframework.checker.nullness.qual.Nullable String> - getConnectionInitSql(); + public abstract Integer getFetchSize(); + @SchemaFieldDescription( + "Type of JDBC source. When specified, an appropriate default Driver will be packaged with the transform. 
One of mysql, postgres, oracle, or mssql.") @Nullable - public abstract String getReadQuery(); + public abstract String getJdbcType(); + @SchemaFieldDescription("Name of the table to read from.") @Nullable public abstract String getLocation(); + @SchemaFieldDescription( + "Whether to reshuffle the resulting PCollection so results are distributed to all workers.") @Nullable - public abstract Short getFetchSize(); + public abstract Boolean getOutputParallelization(); + @SchemaFieldDescription("Password for the JDBC source.") @Nullable - public abstract Boolean getOutputParallelization(); + public abstract String getPassword(); + @SchemaFieldDescription("SQL query used to query the JDBC source.") @Nullable - public abstract Boolean getDisableAutoCommit(); + public abstract String getReadQuery(); + @SchemaFieldDescription("Username for the JDBC source.") @Nullable - public abstract String getDriverJars(); + public abstract String getUsername(); + + public void validate() { + validate(""); + } - public void validate() throws IllegalArgumentException { + public void validate(String jdbcType) throws IllegalArgumentException { if (Strings.isNullOrEmpty(getJdbcUrl())) { throw new IllegalArgumentException("JDBC URL cannot be blank"); } + jdbcType = !Strings.isNullOrEmpty(jdbcType) ? 
jdbcType : getJdbcType(); + boolean driverClassNamePresent = !Strings.isNullOrEmpty(getDriverClassName()); - boolean jdbcTypePresent = !Strings.isNullOrEmpty(getJdbcType()); - if (driverClassNamePresent && jdbcTypePresent) { + boolean driverJarsPresent = !Strings.isNullOrEmpty(getDriverJars()); + boolean jdbcTypePresent = !Strings.isNullOrEmpty(jdbcType); + if (!driverClassNamePresent && !driverJarsPresent && !jdbcTypePresent) { throw new IllegalArgumentException( - "JDBC Driver class name and JDBC type are mutually exclusive configurations."); + "If JDBC type is not specified, then Driver Class Name and Driver Jars must be specified."); } if (!driverClassNamePresent && !jdbcTypePresent) { throw new IllegalArgumentException( "One of JDBC Driver class name or JDBC type must be specified."); } if (jdbcTypePresent - && !JDBC_DRIVER_MAP.containsKey(Objects.requireNonNull(getJdbcType()).toLowerCase())) { + && !JDBC_DRIVER_MAP.containsKey(Objects.requireNonNull(jdbcType).toLowerCase())) { throw new IllegalArgumentException("JDBC type must be one of " + JDBC_DRIVER_MAP.keySet()); } @@ -208,11 +342,10 @@ public void validate() throws IllegalArgumentException { boolean locationPresent = (getLocation() != null && !"".equals(getLocation())); if (readQueryPresent && locationPresent) { - throw new IllegalArgumentException( - "ReadQuery and Location are mutually exclusive configurations"); + throw new IllegalArgumentException("Query and Table are mutually exclusive configurations"); } if (!readQueryPresent && !locationPresent) { - throw new IllegalArgumentException("Either ReadQuery or Location must be set."); + throw new IllegalArgumentException("Either Query or Table must be specified."); } } @@ -241,7 +374,7 @@ public abstract static class Builder { public abstract Builder setConnectionInitSql(List value); - public abstract Builder setFetchSize(Short value); + public abstract Builder setFetchSize(Integer value); public abstract Builder setOutputParallelization(Boolean value); 
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java index c0f7d68899b3..503b64e4a446 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java @@ -83,20 +83,25 @@ import org.slf4j.LoggerFactory; /** Provides utility functions for working with {@link JdbcIO}. */ -class JdbcUtil { +public class JdbcUtil { private static final Logger LOG = LoggerFactory.getLogger(JdbcUtil.class); + public static final String MYSQL = "mysql"; + public static final String POSTGRES = "postgres"; + public static final String ORACLE = "oracle"; + public static final String MSSQL = "mssql"; + static final Map JDBC_DRIVER_MAP = new HashMap<>( ImmutableMap.of( - "mysql", + MYSQL, "com.mysql.cj.jdbc.Driver", - "postgres", + POSTGRES, "org.postgresql.Driver", - "oracle", + ORACLE, "oracle.jdbc.driver.OracleDriver", - "mssql", + MSSQL, "com.microsoft.sqlserver.jdbc.SQLServerDriver")); @VisibleForTesting diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java index 1f970ba0624f..6f10df56aab5 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import 
org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -54,25 +55,131 @@ public class JdbcWriteSchemaTransformProvider extends TypedSchemaTransformProvider< JdbcWriteSchemaTransformProvider.JdbcWriteSchemaTransformConfiguration> { + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:jdbc_write:v1"; + } + + @Override + public String description() { + return "Write to a JDBC sink using a SQL query or by directly accessing a single table.\n" + + "\n" + + "This transform can be used to write to a JDBC sink using either a given JDBC driver jar " + + "and class name, or by using one of the default packaged drivers given a `jdbc_type`.\n" + + "\n" + + "#### Using a default driver\n" + + "\n" + + "This transform comes packaged with drivers for several popular JDBC distributions. The following " + + "distributions can be declared as the `jdbc_type`: " + + JDBC_DRIVER_MAP.keySet().toString().replaceAll("[\\[\\]]", "") + + ".\n" + + "\n" + + "For example, writing to a MySQL sink using a SQL query: ::" + + "\n" + + " - type: WriteToJdbc\n" + + " config:\n" + + " jdbc_type: mysql\n" + + " url: \"jdbc:mysql://my-host:3306/database\"\n" + + " query: \"INSERT INTO table VALUES(?, ?)\"\n" + + "\n" + + "\n" + + "**Note**: See the following transforms which are built on top of this transform and simplify " + + "this logic for several popular JDBC distributions:\n\n" + + " - WriteToMySql\n" + + " - WriteToPostgres\n" + + " - WriteToOracle\n" + + " - WriteToSqlServer\n" + + "\n" + + "#### Declaring custom JDBC drivers\n" + + "\n" + + "If writing to a JDBC sink not listed above, or if it is necessary to use a custom driver not " + + "packaged with Beam, one must define a JDBC driver and class name.\n" + + "\n" + + "For example, writing to a MySQL table: ::" + + "\n" + + " - type: WriteToJdbc\n" + + " config:\n" + + " driver_jars: \"path/to/some/jdbc.jar\"\n" + + " driver_class_name: 
\"com.mysql.jdbc.Driver\"\n" + + " url: \"jdbc:mysql://my-host:3306/database\"\n" + + " table: \"my-table\"\n" + + "\n" + + "#### Connection Properties\n" + + "\n" + + "Connection properties are properties sent to the Driver used to connect to the JDBC source. For example, " + + "to set the character encoding to UTF-8, one could write: ::\n" + + "\n" + + " - type: WriteToJdbc\n" + + " config:\n" + + " connectionProperties: \"characterEncoding=UTF-8;\"\n" + + " ...\n" + + "All properties should be semi-colon-delimited (e.g. \"key1=value1;key2=value2;\")\n"; + } + + protected String inheritedDescription( + String prettyName, String transformName, String prefix, int port) { + return String.format( + "Write to a %s sink using a SQL query or by directly accessing a single table.%n" + + "%n" + + "This is a special case of WriteToJdbc that includes the " + + "necessary %s Driver and classes.%n" + + "%n" + + "An example of using %s with SQL query: ::%n" + + "%n" + + " - type: %s%n" + + " config:%n" + + " url: \"jdbc:%s://my-host:%d/database\"%n" + + " query: \"INSERT INTO table VALUES(?, ?)\"%n" + + "%n" + + "It is also possible to read a table by specifying a table name. For example, the " + + "following configuration will perform a read on an entire table: ::%n" + + "%n" + + " - type: %s%n" + + " config:%n" + + " url: \"jdbc:%s://my-host:%d/database\"%n" + + " table: \"my-table\"%n" + + "%n" + + "#### Advanced Usage%n" + + "%n" + + "It might be necessary to use a custom JDBC driver that is not packaged with this " + + "transform. 
If that is the case, see WriteToJdbc which " + + "allows for more custom configuration.", + prettyName, + prettyName, + transformName, + transformName, + prefix, + port, + transformName, + prefix, + port); + } + @Override protected @UnknownKeyFor @NonNull @Initialized Class configurationClass() { return JdbcWriteSchemaTransformConfiguration.class; } + protected String jdbcType() { + return ""; + } + @Override protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( JdbcWriteSchemaTransformConfiguration configuration) { - configuration.validate(); - return new JdbcWriteSchemaTransform(configuration); + configuration.validate(jdbcType()); + return new JdbcWriteSchemaTransform(configuration, jdbcType()); } - static class JdbcWriteSchemaTransform extends SchemaTransform implements Serializable { + protected static class JdbcWriteSchemaTransform extends SchemaTransform implements Serializable { JdbcWriteSchemaTransformConfiguration config; + private String jdbcType; - public JdbcWriteSchemaTransform(JdbcWriteSchemaTransformConfiguration config) { + public JdbcWriteSchemaTransform(JdbcWriteSchemaTransformConfiguration config, String jdbcType) { this.config = config; + this.jdbcType = jdbcType; } protected JdbcIO.DataSourceConfiguration dataSourceConfiguration() { @@ -80,7 +187,10 @@ protected JdbcIO.DataSourceConfiguration dataSourceConfiguration() { if (Strings.isNullOrEmpty(driverClassName)) { driverClassName = - JDBC_DRIVER_MAP.get(Objects.requireNonNull(config.getJdbcType()).toLowerCase()); + JDBC_DRIVER_MAP.get( + (Objects.requireNonNull( + !Strings.isNullOrEmpty(jdbcType) ? 
jdbcType : config.getJdbcType())) + .toLowerCase()); } JdbcIO.DataSourceConfiguration dsConfig = @@ -157,11 +267,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } } - @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:jdbc_write:v1"; - } - @Override public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> inputCollectionNames() { @@ -178,60 +283,85 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { @DefaultSchema(AutoValueSchema.class) public abstract static class JdbcWriteSchemaTransformConfiguration implements Serializable { + @SchemaFieldDescription("Connection URL for the JDBC sink.") + public abstract String getJdbcUrl(); + + @SchemaFieldDescription( + "If true, enables using a dynamically determined number of shards to write.") @Nullable - public abstract String getDriverClassName(); + public abstract Boolean getAutosharding(); + @SchemaFieldDescription( + "Sets the connection init sql statements used by the Driver. Only MySQL and MariaDB support this.") @Nullable - public abstract String getJdbcType(); + public abstract List<@org.checkerframework.checker.nullness.qual.Nullable String> + getConnectionInitSql(); - public abstract String getJdbcUrl(); + @SchemaFieldDescription( + "Used to set connection properties passed to the JDBC driver not already defined as standalone parameter (e.g. username and password can be set using parameters above accordingly). Format of the string must be \"key1=value1;key2=value2;\".") + @Nullable + public abstract String getConnectionProperties(); + @SchemaFieldDescription( + "Name of a Java Driver class to use to connect to the JDBC source. For example, \"com.mysql.jdbc.Driver\".") @Nullable - public abstract String getUsername(); + public abstract String getDriverClassName(); + @SchemaFieldDescription( + "Comma separated path(s) for the JDBC driver jar(s). 
This can be a local path or GCS (gs://) path.") @Nullable - public abstract String getPassword(); + public abstract String getDriverJars(); @Nullable - public abstract String getConnectionProperties(); + public abstract Long getBatchSize(); + @SchemaFieldDescription( + "Type of JDBC source. When specified, an appropriate default Driver will be packaged with the transform. One of mysql, postgres, oracle, or mssql.") @Nullable - public abstract List<@org.checkerframework.checker.nullness.qual.Nullable String> - getConnectionInitSql(); + public abstract String getJdbcType(); + @SchemaFieldDescription("Name of the table to write to.") @Nullable public abstract String getLocation(); + @SchemaFieldDescription("Password for the JDBC source.") @Nullable - public abstract String getWriteStatement(); + public abstract String getPassword(); + @SchemaFieldDescription("Username for the JDBC source.") @Nullable - public abstract Boolean getAutosharding(); + public abstract String getUsername(); + @SchemaFieldDescription("SQL query used to insert records into the JDBC sink.") @Nullable - public abstract String getDriverJars(); + public abstract String getWriteStatement(); - @Nullable - public abstract Long getBatchSize(); + public void validate() { + validate(""); + } - public void validate() throws IllegalArgumentException { + public void validate(String jdbcType) throws IllegalArgumentException { if (Strings.isNullOrEmpty(getJdbcUrl())) { throw new IllegalArgumentException("JDBC URL cannot be blank"); } + jdbcType = !Strings.isNullOrEmpty(jdbcType) ? 
jdbcType : getJdbcType(); + boolean driverClassNamePresent = !Strings.isNullOrEmpty(getDriverClassName()); - boolean jdbcTypePresent = !Strings.isNullOrEmpty(getJdbcType()); - if (driverClassNamePresent && jdbcTypePresent) { + boolean driverJarsPresent = !Strings.isNullOrEmpty(getDriverJars()); + boolean jdbcTypePresent = !Strings.isNullOrEmpty(jdbcType); + if (!driverClassNamePresent && !driverJarsPresent && !jdbcTypePresent) { throw new IllegalArgumentException( - "JDBC Driver class name and JDBC type are mutually exclusive configurations."); + "If JDBC type is not specified, then Driver Class Name and Driver Jars must be specified."); } if (!driverClassNamePresent && !jdbcTypePresent) { throw new IllegalArgumentException( "One of JDBC Driver class name or JDBC type must be specified."); } if (jdbcTypePresent - && !JDBC_DRIVER_MAP.containsKey(Objects.requireNonNull(getJdbcType()).toLowerCase())) { - throw new IllegalArgumentException("JDBC type must be one of " + JDBC_DRIVER_MAP.keySet()); + && !JDBC_DRIVER_MAP.containsKey(Objects.requireNonNull(jdbcType).toLowerCase())) { + throw new IllegalArgumentException( + "JDBC type must be one of " + JDBC_DRIVER_MAP.keySet() + " but was " + jdbcType); } boolean writeStatementPresent = @@ -240,10 +370,10 @@ public void validate() throws IllegalArgumentException { if (writeStatementPresent && locationPresent) { throw new IllegalArgumentException( - "ReadQuery and Location are mutually exclusive configurations"); + "Write Statement and Table are mutually exclusive configurations"); } if (!writeStatementPresent && !locationPresent) { - throw new IllegalArgumentException("Either ReadQuery or Location must be set."); + throw new IllegalArgumentException("Either Write Statement or Table must be set."); } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java 
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java new file mode 100644 index 000000000000..3d0135ef8ecd --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.MYSQL; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class ReadFromMySqlSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:mysql_read:v1"; + } + + @Override + public String description() { + return inheritedDescription("MySQL", "ReadFromMySql", "mysql", 3306); + } + + @Override + protected String jdbcType() { + return MYSQL; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java new file mode 100644 index 000000000000..de18d5aa8189 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.ORACLE; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class ReadFromOracleSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:oracle_read:v1"; + } + + @Override + public String description() { + return inheritedDescription("Oracle", "ReadFromOracle", "oracle", 1521); + } + + @Override + protected String jdbcType() { + return ORACLE; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java new file mode 100644 index 000000000000..62ff14c23e0a --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.POSTGRES; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class ReadFromPostgresSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:postgres_read:v1"; + } + + @Override + public String description() { + return inheritedDescription("Postgres", "ReadFromPostgres", "postgresql", 5432); + } + + @Override + protected String jdbcType() { + return POSTGRES; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromSqlServerSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromSqlServerSchemaTransformProvider.java new file mode 100644 index 
000000000000..e4767177bb2f --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromSqlServerSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.MSSQL; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class ReadFromSqlServerSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:sql_server_read:v1"; + } + + @Override + public String description() { + return inheritedDescription("SQL Server", "ReadFromSqlServer", "sqlserver", 1433); + } + + @Override + protected String jdbcType() { + return MSSQL; + } +} diff 
--git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java new file mode 100644 index 000000000000..57f085220162 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.MYSQL; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class WriteToMySqlSchemaTransformProvider extends JdbcWriteSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:mysql_write:v1"; + } + + @Override + public String description() { + return inheritedDescription("MySQL", "WriteToMySql", "mysql", 3306); + } + + @Override + protected String jdbcType() { + return MYSQL; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java new file mode 100644 index 000000000000..5b3ae2c35e9d --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.ORACLE; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class WriteToOracleSchemaTransformProvider extends JdbcWriteSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:oracle_write:v1"; + } + + @Override + public String description() { + return inheritedDescription("Oracle", "WriteToOracle", "oracle", 1521); + } + + @Override + protected String jdbcType() { + return ORACLE; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java new file mode 100644 index 000000000000..c50b84311630 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.POSTGRES; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class WriteToPostgresSchemaTransformProvider extends JdbcWriteSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:postgres_write:v1"; + } + + @Override + public String description() { + return inheritedDescription("Postgres", "WriteToPostgres", "postgresql", 5432); + } + + @Override + protected String jdbcType() { + return POSTGRES; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToSqlServerSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToSqlServerSchemaTransformProvider.java new file mode 100644 index 
000000000000..9e849f4e49e2 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToSqlServerSchemaTransformProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.JdbcUtil.MSSQL; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; + +@AutoService(SchemaTransformProvider.class) +public class WriteToSqlServerSchemaTransformProvider extends JdbcWriteSchemaTransformProvider { + + @Override + public @UnknownKeyFor @NonNull @Initialized String identifier() { + return "beam:schematransform:org.apache.beam:sql_server_write:v1"; + } + + @Override + public String description() { + return inheritedDescription("SQL Server", "WriteToSqlServer", "sqlserver", 1433); + } + + @Override + protected String jdbcType() { + return MSSQL; + } +} diff 
--git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/package-info.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/package-info.java similarity index 88% rename from sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/package-info.java rename to sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/package-info.java index 1b76a71ae647..db5bba936596 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/package-info.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/package-info.java @@ -15,5 +15,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/** Defines common coders for Amazon Web Services. */ -package org.apache.beam.sdk.io.aws.coders; + +/** Transforms for reading and writing from JDBC. */ +package org.apache.beam.sdk.io.jdbc.providers; diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProviderTest.java index 7cbdd48d1587..ca7690ac9a08 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProviderTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProviderTest.java @@ -122,17 +122,6 @@ public void testInvalidReadSchemaOptions() { .build() .validate(); }); - assertThrows( - IllegalArgumentException.class, - () -> { - JdbcReadSchemaTransformProvider.JdbcReadSchemaTransformConfiguration.builder() - .setJdbcUrl("JdbcUrl") - .setLocation("Location") - .setDriverClassName("ClassName") - .setJdbcType((String) JDBC_DRIVER_MAP.keySet().toArray()[0]) - .build() - .validate(); - }); } @Test diff --git 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java index d6be4d9f89c8..a8d9162f3a8e 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java @@ -122,17 +122,6 @@ public void testInvalidWriteSchemaOptions() { .build() .validate(); }); - assertThrows( - IllegalArgumentException.class, - () -> { - JdbcWriteSchemaTransformProvider.JdbcWriteSchemaTransformConfiguration.builder() - .setJdbcUrl("JdbcUrl") - .setLocation("Location") - .setDriverClassName("ClassName") - .setJdbcType((String) JDBC_DRIVER_MAP.keySet().toArray()[0]) - .build() - .validate(); - }); } @Test diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle index c2f056b0b7cb..04563c478d6d 100644 --- a/sdks/java/io/kafka/build.gradle +++ b/sdks/java/io/kafka/build.gradle @@ -31,7 +31,8 @@ enableJavaPerformanceTesting() description = "Apache Beam :: SDKs :: Java :: IO :: Kafka" ext { summary = "Library to read Kafka topics." - confluentVersion = "7.6.0" + // newer versions e.g. 7.6.* require dropping support for older kafka versions. 
+ confluentVersion = "7.5.5" } def kafkaVersions = [ diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsPipelineOptionsRegistrar.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOInitializer.java similarity index 60% rename from sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsPipelineOptionsRegistrar.java rename to sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOInitializer.java index 3dad9fd611cb..3dfb31715ced 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/options/AwsPipelineOptionsRegistrar.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOInitializer.java @@ -15,22 +15,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.aws.options; +package org.apache.beam.sdk.io.kafka; import com.google.auto.service.AutoService; +import org.apache.beam.sdk.harness.JvmInitializer; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsRegistrar; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; - -/** A registrar containing the default AWS options. */ -@AutoService(PipelineOptionsRegistrar.class) -public class AwsPipelineOptionsRegistrar implements PipelineOptionsRegistrar { +/** Initialize KafkaIO feature flags on worker. 
*/ +@AutoService(JvmInitializer.class) +public class KafkaIOInitializer implements JvmInitializer { @Override - public Iterable> getPipelineOptions() { - return ImmutableList.>builder() - .add(AwsOptions.class) - .add(S3Options.class) - .build(); + public void beforeProcessing(PipelineOptions options) { + if (ExperimentalOptions.hasExperiment(options, "enable_kafka_metrics")) { + KafkaSinkMetrics.setSupportKafkaMetrics(true); + } } } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index 069607955c6d..ab9e26b3b740 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -606,13 +606,20 @@ private void commitCheckpointMark() { LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); Consumer consumer = Preconditions.checkStateNotNull(this.consumer); - consumer.commitSync( - checkpointMark.getPartitions().stream() - .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) - .collect( - Collectors.toMap( - p -> new TopicPartition(p.getTopic(), p.getPartition()), - p -> new OffsetAndMetadata(p.getNextOffset())))); + try { + consumer.commitSync( + checkpointMark.getPartitions().stream() + .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) + .collect( + Collectors.toMap( + p -> new TopicPartition(p.getTopic(), p.getPartition()), + p -> new OffsetAndMetadata(p.getNextOffset())))); + } catch (Exception e) { + // Log but ignore the exception. Committing consumer offsets to Kafka is not critical for + // KafkaIO because it relies on the offsets stored in KafkaCheckpointMark. 
+ LOG.warn( + String.format("%s: Could not commit finalized checkpoint %s", this, checkpointMark), e); + } } } diff --git a/sdks/java/io/kinesis/build.gradle b/sdks/java/io/kinesis/build.gradle deleted file mode 100644 index 60058f4469ad..000000000000 --- a/sdks/java/io/kinesis/build.gradle +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -plugins { id 'org.apache.beam.module' } -applyJavaNature( automaticModuleName: 'org.apache.beam.sdk.io.kinesis') -provideIntegrationTestingDependencies() -enableJavaPerformanceTesting() - -description = "Apache Beam :: SDKs :: Java :: IO :: Kinesis" -ext.summary = "Library to read Kinesis streams." 
- -test { - maxParallelForks 4 -} - -dependencies { - implementation project(path: ":sdks:java:core", configuration: "shadow") - implementation library.java.aws_java_sdk_cloudwatch - implementation library.java.aws_java_sdk_core - implementation library.java.aws_java_sdk_kinesis - implementation library.java.commons_lang3 - implementation library.java.guava - implementation library.java.joda_time - implementation library.java.slf4j_api - implementation "com.amazonaws:amazon-kinesis-client:1.14.2" - implementation "com.amazonaws:amazon-kinesis-producer:0.14.1" - implementation "commons-lang:commons-lang:2.6" - implementation library.java.vendored_guava_32_1_2_jre - implementation library.java.jackson_core - implementation library.java.jackson_annotations - implementation library.java.jackson_databind - testImplementation project(path: ":sdks:java:io:common") - testImplementation library.java.junit - testImplementation library.java.mockito_core - testImplementation library.java.guava_testlib - testImplementation library.java.powermock - testImplementation library.java.powermock_mockito - testImplementation library.java.testcontainers_localstack - testImplementation "org.assertj:assertj-core:3.11.1" - testRuntimeOnly library.java.slf4j_jdk14 - testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java deleted file mode 100644 index fa3351ccf778..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/AWSClientsProvider.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.producer.IKinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; -import java.io.Serializable; - -/** - * Provides instances of AWS clients. - * - *

Please note, that any instance of {@link AWSClientsProvider} must be {@link Serializable} to - * ensure it can be sent to worker machines. - */ -public interface AWSClientsProvider extends Serializable { - AmazonKinesis getKinesisClient(); - - AmazonCloudWatch getCloudWatchClient(); - - IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config); -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/BasicKinesisProvider.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/BasicKinesisProvider.java deleted file mode 100644 index ada59996609e..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/BasicKinesisProvider.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; -import com.amazonaws.services.kinesis.producer.IKinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; -import java.net.URI; -import java.util.Objects; -import org.apache.beam.sdk.io.kinesis.serde.AwsSerializableUtils; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** Basic implementation of {@link AWSClientsProvider} used by default in {@link KinesisIO}. 
*/ -class BasicKinesisProvider implements AWSClientsProvider { - private final String awsCredentialsProviderSerialized; - private final Regions region; - private final @Nullable String serviceEndpoint; - private final boolean verifyCertificate; - - BasicKinesisProvider( - AWSCredentialsProvider awsCredentialsProvider, - Regions region, - @Nullable String serviceEndpoint, - boolean verifyCertificate) { - checkArgument(awsCredentialsProvider != null, "awsCredentialsProvider can not be null"); - checkArgument(region != null, "region can not be null"); - this.awsCredentialsProviderSerialized = AwsSerializableUtils.serialize(awsCredentialsProvider); - checkNotNull(awsCredentialsProviderSerialized, "awsCredentialsProviderString can not be null"); - this.region = region; - this.serviceEndpoint = serviceEndpoint; - this.verifyCertificate = verifyCertificate; - } - - private AWSCredentialsProvider getCredentialsProvider() { - return AwsSerializableUtils.deserialize(awsCredentialsProviderSerialized); - } - - @Override - public AmazonKinesis getKinesisClient() { - AmazonKinesisClientBuilder clientBuilder = - AmazonKinesisClientBuilder.standard().withCredentials(getCredentialsProvider()); - if (serviceEndpoint == null) { - clientBuilder.withRegion(region); - } else { - clientBuilder.withEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName())); - } - return clientBuilder.build(); - } - - @Override - public AmazonCloudWatch getCloudWatchClient() { - AmazonCloudWatchClientBuilder clientBuilder = - AmazonCloudWatchClientBuilder.standard().withCredentials(getCredentialsProvider()); - if (serviceEndpoint == null) { - clientBuilder.withRegion(region); - } else { - clientBuilder.withEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, region.getName())); - } - return clientBuilder.build(); - } - - @Override - public IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config) { - 
config.setRegion(region.getName()); - config.setCredentialsProvider(getCredentialsProvider()); - if (serviceEndpoint != null) { - URI uri = URI.create(serviceEndpoint); - config.setKinesisEndpoint(uri.getHost()); - config.setKinesisPort(uri.getPort()); - } - config.setVerifyCertificate(verifyCertificate); - return new KinesisProducer(config); - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BasicKinesisProvider that = (BasicKinesisProvider) o; - return verifyCertificate == that.verifyCertificate - && Objects.equals(awsCredentialsProviderSerialized, that.awsCredentialsProviderSerialized) - && Objects.equals(region, that.region) - && Objects.equals(serviceEndpoint, that.serviceEndpoint); - } - - @Override - public int hashCode() { - return Objects.hash( - awsCredentialsProviderSerialized, region, serviceEndpoint, verifyCertificate); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java deleted file mode 100644 index 08515c7f3457..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.io.Serializable; - -/** - * Used to generate checkpoint object on demand. How exactly the checkpoint is generated is up to - * implementing class. - */ -interface CheckpointGenerator extends Serializable { - - KinesisReaderCheckpoint generate(SimplifiedKinesisClient client) throws TransientKinesisException; -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java deleted file mode 100644 index 1baeddd3bf8f..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import java.util.NoSuchElementException; -import java.util.Objects; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * Similar to Guava {@code Optional}, but throws {@link NoSuchElementException} for missing element. - */ -abstract class CustomOptional { - - @SuppressWarnings("unchecked") - public static CustomOptional absent() { - return (Absent) Absent.INSTANCE; - } - - public static CustomOptional of(T v) { - return new Present<>(v); - } - - public abstract boolean isPresent(); - - public abstract T get(); - - private static class Present extends CustomOptional { - - private final T value; - - private Present(T value) { - this.value = value; - } - - @Override - public boolean isPresent() { - return true; - } - - @Override - public T get() { - return value; - } - - @Override - public boolean equals(@Nullable Object o) { - if (!(o instanceof Present)) { - return false; - } - - Present present = (Present) o; - return Objects.equals(value, present.value); - } - - @Override - public int hashCode() { - return Objects.hash(value); - } - } - - private static class Absent extends CustomOptional { - - private static final Absent INSTANCE = new Absent<>(); - - private Absent() {} - - @Override - public boolean isPresent() { - return false; - } - - @Override - public T get() { - throw new NoSuchElementException(); - } - - @Override - public boolean equals(@Nullable Object o) { - return o instanceof Absent; - } - - @Override - public int hashCode() { - return 0; - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java deleted file mode 100644 index 8ef1274947b5..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache 
Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.services.kinesis.model.Shard; -import java.util.List; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Creates {@link KinesisReaderCheckpoint}, which spans over all shards in given stream. List of - * shards is obtained dynamically on call to {@link #generate(SimplifiedKinesisClient)}. 
- */ -class DynamicCheckpointGenerator implements CheckpointGenerator { - - private static final Logger LOG = LoggerFactory.getLogger(DynamicCheckpointGenerator.class); - private final String streamName; - private final StartingPoint startingPoint; - - public DynamicCheckpointGenerator(String streamName, StartingPoint startingPoint) { - this.streamName = streamName; - this.startingPoint = startingPoint; - } - - @Override - public KinesisReaderCheckpoint generate(SimplifiedKinesisClient kinesis) - throws TransientKinesisException { - List streamShards = kinesis.listShardsAtPoint(streamName, startingPoint); - LOG.info( - "Creating a checkpoint with following shards {} at {}", - streamShards, - startingPoint.getTimestamp()); - return new KinesisReaderCheckpoint( - streamShards.stream() - .map(shard -> new ShardCheckpoint(streamName, shard.getShardId(), startingPoint)) - .collect(Collectors.toList())); - } - - @Override - public String toString() { - return String.format("Checkpoint generator for %s: %s", streamName, startingPoint); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java deleted file mode 100644 index 6fefb43dee0f..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; -import java.util.List; -import java.util.stream.Collectors; - -/** Represents the output of 'get' operation on Kinesis stream. */ -class GetKinesisRecordsResult { - - private final List records; - private final String nextShardIterator; - private final long millisBehindLatest; - - public GetKinesisRecordsResult( - List records, - String nextShardIterator, - long millisBehindLatest, - final String streamName, - final String shardId) { - this.records = - records.stream() - .map( - input -> { - assert input != null; // to make FindBugs happy - return new KinesisRecord(input, streamName, shardId); - }) - .collect(Collectors.toList()); - this.nextShardIterator = nextShardIterator; - this.millisBehindLatest = millisBehindLatest; - } - - public List getRecords() { - return records; - } - - public String getNextShardIterator() { - return nextShardIterator; - } - - public long getMillisBehindLatest() { - return millisBehindLatest; - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java deleted file mode 100644 index 0cf4bdb0d85b..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor 
license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.AmazonClientException; - -/** Thrown when the Kinesis client was throttled due to rate limits. */ -public class KinesisClientThrottledException extends TransientKinesisException { - - public KinesisClientThrottledException(String s, AmazonClientException e) { - super(s, e); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java deleted file mode 100644 index 3b64a6e71947..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java +++ /dev/null @@ -1,1116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.producer.Attempt; -import com.amazonaws.services.kinesis.producer.IKinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; -import com.amazonaws.services.kinesis.producer.UserRecordFailedException; -import com.amazonaws.services.kinesis.producer.UserRecordResult; -import com.google.auto.value.AutoValue; -import com.google.common.util.concurrent.ListenableFuture; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingDeque; -import java.util.function.Supplier; -import org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.io.Read.Unbounded; -import org.apache.beam.sdk.transforms.DoFn; -import 
org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.values.PBegin; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PDone; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * {@link PTransform}s for reading from and writing to Kinesis streams. - * - *

Reading from Kinesis

- * - *

Example usages: - * - *

{@code
- * p.apply(KinesisIO.read()
- *     .withStreamName("streamName")
- *     .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *     // using AWS default credentials provider chain (recommended)
- *     .withAWSClientsProvider(DefaultAWSCredentialsProviderChain.getInstance(), STREAM_REGION)
- *  .apply( ... ) // other transformations
- * }
- * - *
{@code
- * p.apply(KinesisIO.read()
- *     .withStreamName("streamName")
- *     .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *     // using plain AWS key and secret
- *     .withAWSClientsProvider("AWS_KEY", "AWS_SECRET", STREAM_REGION)
- *  .apply( ... ) // other transformations
- * }
- * - *

As you can see you need to provide 3 things: - * - *

    - *
  • name of the stream you're going to read - *
  • position in the stream where reading should start. There are two options: - *
      - *
    • {@link InitialPositionInStream#LATEST} - reading will begin from end of the stream - *
    • {@link InitialPositionInStream#TRIM_HORIZON} - reading will begin at the very - * beginning of the stream - *
    - *
  • data used to initialize {@link AmazonKinesis} and {@link AmazonCloudWatch} clients: - *
      - *
    • AWS credentials - *
    • region where the stream is located - *
    - *
- * - *

In case when you want to set up {@link AmazonKinesis} or {@link AmazonCloudWatch} client by - * your own (for example if you're using more sophisticated authorization methods like Amazon STS, - * etc.) you can do it by implementing {@link AWSClientsProvider} class: - * - *

{@code
- * public class MyCustomKinesisClientProvider implements AWSClientsProvider {
- *   public AmazonKinesis getKinesisClient() {
- *     // set up your client here
- *   }
- *
- *   public AmazonCloudWatch getCloudWatchClient() {
- *     // set up your client here
- *   }
- *
- * }
- * }
- * - *

Usage is pretty straightforward: - * - *

{@code
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withAWSClientsProvider(new MyCustomKinesisClientProvider())
- *  .apply( ... ) // other transformations
- * }
- * - *

There's also possibility to start reading using arbitrary point in time - in this case you - * need to provide {@link Instant} object: - * - *

{@code
- * p.apply(KinesisIO.read()
- *     .withStreamName("streamName")
- *     .withInitialTimestampInStream(instant)
- *     .withAWSClientsProvider(new MyCustomKinesisClientProvider())
- *  .apply( ... ) // other transformations
- * }
- * - *

Kinesis IO uses ArrivalTimeWatermarkPolicy by default. To use Processing time as event time: - * - *

{@code
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withProcessingTimeWatermarkPolicy())
- * }
- * - *

It is also possible to specify a custom watermark policy to control watermark computation. - * Below is an example - * - *

{@code
- * // custom policy
- * class MyCustomPolicy implements WatermarkPolicy {
- *     private WatermarkPolicyFactory.CustomWatermarkPolicy customWatermarkPolicy;
- *
- *     MyCustomPolicy() {
- *       this.customWatermarkPolicy = new WatermarkPolicyFactory.CustomWatermarkPolicy(WatermarkParameters.create());
- *     }
- *
- *     public Instant getWatermark() {
- *       return customWatermarkPolicy.getWatermark();
- *     }
- *
- *     public void update(KinesisRecord record) {
- *       customWatermarkPolicy.update(record);
- *     }
- *   }
- *
- * // custom factory
- * class MyCustomPolicyFactory implements WatermarkPolicyFactory {
- *     public WatermarkPolicy createWatermarkPolicy() {
- *       return new MyCustomPolicy();
- *     }
- * }
- *
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withCustomWatermarkPolicy(new MyCustomPolicyFactory())
- * }
- * - *

By default Kinesis IO will poll the Kinesis getRecords() API as fast as possible which may - * lead to excessive read throttling. To limit the rate of getRecords() calls you can set a rate - * limit policy. For example, the default fixed delay policy will limit the rate to one API call per - * second per shard: - * - *

{@code
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withFixedDelayRateLimitPolicy())
- * }
- * - *

You can also use a fixed delay policy with a specified delay interval, for example: - * - *

{@code
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withFixedDelayRateLimitPolicy(Duration.millis(500))
- * }
- * - *

If you need to change the polling interval of a Kinesis pipeline at runtime, for example to - * compensate for adding and removing additional consumers to the stream, then you can supply the - * delay interval as a function so that you can obtain the current delay interval from some external - * source: - * - *

{@code
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withDynamicDelayRateLimitPolicy(() -> Duration.millis())
- * }
- * - *

Finally, you can create a custom rate limit policy that responds to successful read calls - * and/or read throttling exceptions with your own rate-limiting logic: - * - *

{@code
- * // custom policy
- * public class MyCustomPolicy implements RateLimitPolicy {
- *
- *   public void onSuccess(List records) throws InterruptedException {
- *     // handle successful getRecords() call
- *   }
- *
- *   public void onThrottle(KinesisClientThrottledException e) throws InterruptedException {
- *     // handle Kinesis read throttling exception
- *   }
- * }
- *
- * // custom factory
- * class MyCustomPolicyFactory implements RateLimitPolicyFactory {
- *
- *   public RateLimitPolicy getRateLimitPolicy() {
- *     return new MyCustomPolicy();
- *   }
- * }
- *
- * p.apply(KinesisIO.read()
- *    .withStreamName("streamName")
- *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
- *    .withCustomRateLimitPolicy(new MyCustomPolicyFactory())
- * }
- * - *

Writing to Kinesis

- * - *

Example usages: - * - *

{@code
- * PCollection data = ...;
- *
- * data.apply(KinesisIO.write()
- *     .withStreamName("streamName")
- *     .withPartitionKey("partitionKey")
- *     // using AWS default credentials provider chain (recommended)
- *     .withAWSClientsProvider(DefaultAWSCredentialsProviderChain.getInstance(), STREAM_REGION));
- * }
- * - *
{@code
- * PCollection data = ...;
- *
- * data.apply(KinesisIO.write()
- *     .withStreamName("streamName")
- *     .withPartitionKey("partitionKey")
- *      // using plain AWS key and secret
- *     .withAWSClientsProvider("AWS_KEY", "AWS_SECRET", STREAM_REGION));
- * }
- * - *

As a client, you need to provide at least 3 things: - * - *

    - *
  • name of the stream where you're going to write - *
  • partition key (or implementation of {@link KinesisPartitioner}) that defines which - * partition will be used for writing - *
  • data used to initialize {@link AmazonKinesis} and {@link AmazonCloudWatch} clients: - *
      - *
    • AWS credentials - *
    • region where the stream is located - *
    - *
- * - *

In case if you need to define more complicated logic for key partitioning then you can create - * your own implementation of {@link KinesisPartitioner} and set it by {@link - * KinesisIO.Write#withPartitioner(KinesisPartitioner)} - * - *

Internally, {@link KinesisIO.Write} relies on Amazon Kinesis Producer Library (KPL). This - * library can be configured with a set of {@link Properties} if needed. - * - *

Example usage of KPL configuration: - * - *

{@code
- * Properties properties = new Properties();
- * properties.setProperty("KinesisEndpoint", "localhost");
- * properties.setProperty("KinesisPort", "4567");
- *
- * PCollection data = ...;
- *
- * data.apply(KinesisIO.write()
- *     .withStreamName("streamName")
- *     .withPartitionKey("partitionKey")
- *     .withAWSClientsProvider(AWS_KEY, AWS_SECRET, STREAM_REGION)
- *     .withProducerProperties(properties));
- * }
- * - *

For more information about configuratiom parameters, see the sample - * of configuration file. - * - * @deprecated Module beam-sdks-java-io-kinesis is deprecated and will be eventually - * removed. Please migrate to {@link org.apache.beam.sdk.io.aws2.kinesis.KinesisIO} in module - * beam-sdks-java-io-amazon-web-services2. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Deprecated -public final class KinesisIO { - - private static final Logger LOG = LoggerFactory.getLogger(KinesisIO.class); - - private static final int DEFAULT_NUM_RETRIES = 6; - - /** Returns a new {@link Read} transform for reading from Kinesis. */ - public static Read read() { - return Read.newBuilder().setCoder(KinesisRecordCoder.of()).build(); - } - - /** - * A {@link PTransform} to read from Kinesis stream as bytes without metadata and returns a {@link - * PCollection} of {@link byte[]}. - */ - public static Read readData() { - return Read.newBuilder(KinesisRecord::getDataAsBytes).setCoder(ByteArrayCoder.of()).build(); - } - - /** A {@link PTransform} writing data to Kinesis. */ - public static Write write() { - return new AutoValue_KinesisIO_Write.Builder().setRetries(DEFAULT_NUM_RETRIES).build(); - } - - /** Implementation of {@link #read}. 
*/ - @AutoValue - public abstract static class Read extends PTransform> { - - abstract @Nullable String getStreamName(); - - abstract @Nullable StartingPoint getInitialPosition(); - - abstract @Nullable AWSClientsProvider getAWSClientsProvider(); - - abstract long getMaxNumRecords(); - - abstract @Nullable Duration getMaxReadTime(); - - abstract Duration getUpToDateThreshold(); - - abstract @Nullable Integer getRequestRecordsLimit(); - - abstract WatermarkPolicyFactory getWatermarkPolicyFactory(); - - abstract RateLimitPolicyFactory getRateLimitPolicyFactory(); - - abstract Integer getMaxCapacityPerShard(); - - abstract Coder getCoder(); - - abstract @Nullable SerializableFunction getParseFn(); - - abstract Builder toBuilder(); - - static Builder newBuilder(SerializableFunction parseFn) { - return new AutoValue_KinesisIO_Read.Builder() - .setParseFn(parseFn) - .setMaxNumRecords(Long.MAX_VALUE) - .setUpToDateThreshold(Duration.ZERO) - .setWatermarkPolicyFactory(WatermarkPolicyFactory.withArrivalTimePolicy()) - .setRateLimitPolicyFactory(RateLimitPolicyFactory.withDefaultRateLimiter()) - .setMaxCapacityPerShard(ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD); - } - - static Builder newBuilder() { - return newBuilder(x -> x); - } - - @AutoValue.Builder - abstract static class Builder { - - abstract Builder setStreamName(String streamName); - - abstract Builder setInitialPosition(StartingPoint startingPoint); - - abstract Builder setAWSClientsProvider(AWSClientsProvider clientProvider); - - abstract Builder setMaxNumRecords(long maxNumRecords); - - abstract Builder setMaxReadTime(Duration maxReadTime); - - abstract Builder setUpToDateThreshold(Duration upToDateThreshold); - - abstract Builder setRequestRecordsLimit(Integer limit); - - abstract Builder setWatermarkPolicyFactory(WatermarkPolicyFactory watermarkPolicyFactory); - - abstract Builder setRateLimitPolicyFactory(RateLimitPolicyFactory rateLimitPolicyFactory); - - abstract Builder setMaxCapacityPerShard(Integer 
maxCapacity); - - abstract Builder setParseFn(SerializableFunction parseFn); - - abstract Builder setCoder(Coder coder); - - abstract Read build(); - } - - /** Specify reading from streamName. */ - public Read withStreamName(String streamName) { - return toBuilder().setStreamName(streamName).build(); - } - - /** Specify reading from some initial position in stream. */ - public Read withInitialPositionInStream(InitialPositionInStream initialPosition) { - return toBuilder().setInitialPosition(new StartingPoint(initialPosition)).build(); - } - - /** - * Specify reading beginning at given {@link Instant}. This {@link Instant} must be in the past, - * i.e. before {@link Instant#now()}. - */ - public Read withInitialTimestampInStream(Instant initialTimestamp) { - return toBuilder().setInitialPosition(new StartingPoint(initialTimestamp)).build(); - } - - /** - * Allows to specify custom {@link AWSClientsProvider}. {@link AWSClientsProvider} provides - * {@link AmazonKinesis} and {@link AmazonCloudWatch} instances which are later used for - * communication with Kinesis. You should use this method if {@link - * Read#withAWSClientsProvider(AWSCredentialsProvider, Regions)} does not suit your needs. - */ - public Read withAWSClientsProvider(AWSClientsProvider awsClientsProvider) { - return toBuilder().setAWSClientsProvider(awsClientsProvider).build(); - } - - /** - * Specify {@link AWSCredentialsProvider} and region to be used to read from Kinesis. If you - * need more sophisticated credential protocol, then you should look at {@link - * Read#withAWSClientsProvider(AWSClientsProvider)}. - */ - public Read withAWSClientsProvider( - AWSCredentialsProvider awsCredentialsProvider, Regions region) { - return withAWSClientsProvider(awsCredentialsProvider, region, null); - } - - /** - * Specify credential details and region to be used to read from Kinesis. 
If you need more - * sophisticated credential protocol, then you should look at {@link - * Read#withAWSClientsProvider(AWSClientsProvider)}. - */ - public Read withAWSClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region) { - return withAWSClientsProvider(awsAccessKey, awsSecretKey, region, null); - } - - /** - * Specify {@link AWSCredentialsProvider} and region to be used to read from Kinesis. If you - * need more sophisticated credential protocol, then you should look at {@link - * Read#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with a kinesis service emulator. - */ - public Read withAWSClientsProvider( - AWSCredentialsProvider awsCredentialsProvider, Regions region, String serviceEndpoint) { - return withAWSClientsProvider(awsCredentialsProvider, region, serviceEndpoint, true); - } - - /** - * Specify credential details and region to be used to read from Kinesis. If you need more - * sophisticated credential protocol, then you should look at {@link - * Read#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with a kinesis service emulator. - */ - public Read withAWSClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region, String serviceEndpoint) { - return withAWSClientsProvider(awsAccessKey, awsSecretKey, region, serviceEndpoint, true); - } - - /** - * Specify {@link AWSCredentialsProvider} and region to be used to read from Kinesis. If you - * need more sophisticated credential protocol, then you should look at {@link - * Read#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - * - *

The {@code verifyCertificate} disables or enables certificate verification. Never set it - * to false in production. - */ - public Read withAWSClientsProvider( - AWSCredentialsProvider awsCredentialsProvider, - Regions region, - String serviceEndpoint, - boolean verifyCertificate) { - return withAWSClientsProvider( - new BasicKinesisProvider( - awsCredentialsProvider, region, serviceEndpoint, verifyCertificate)); - } - - /** - * Specify credential details and region to be used to read from Kinesis. If you need more - * sophisticated credential protocol, then you should look at {@link - * Read#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - * - *

The {@code verifyCertificate} disables or enables certificate verification. Never set it - * to false in production. - */ - public Read withAWSClientsProvider( - String awsAccessKey, - String awsSecretKey, - Regions region, - String serviceEndpoint, - boolean verifyCertificate) { - AWSCredentialsProvider awsCredentialsProvider = - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKey, awsSecretKey)); - return withAWSClientsProvider( - awsCredentialsProvider, region, serviceEndpoint, verifyCertificate); - } - - /** Specifies to read at most a given number of records. */ - public Read withMaxNumRecords(long maxNumRecords) { - checkArgument( - maxNumRecords > 0, "maxNumRecords must be positive, but was: %s", maxNumRecords); - return toBuilder().setMaxNumRecords(maxNumRecords).build(); - } - - /** Specifies to read records during {@code maxReadTime}. */ - public Read withMaxReadTime(Duration maxReadTime) { - checkArgument(maxReadTime != null, "maxReadTime can not be null"); - return toBuilder().setMaxReadTime(maxReadTime).build(); - } - - /** - * Specifies how late records consumed by this source can be to still be considered on time. - * When this limit is exceeded the actual backlog size will be evaluated and the runner might - * decide to scale the amount of resources allocated to the pipeline in order to speed up - * ingestion. - */ - public Read withUpToDateThreshold(Duration upToDateThreshold) { - checkArgument(upToDateThreshold != null, "upToDateThreshold can not be null"); - return toBuilder().setUpToDateThreshold(upToDateThreshold).build(); - } - - /** - * Specifies the maximum number of records in GetRecordsResult returned by GetRecords call which - * is limited by 10K records. If should be adjusted according to average size of data record to - * prevent shard overloading. 
More details can be found here: API_GetRecords - */ - public Read withRequestRecordsLimit(int limit) { - checkArgument(limit > 0, "limit must be positive, but was: %s", limit); - checkArgument(limit <= 10_000, "limit must be up to 10,000, but was: %s", limit); - return toBuilder().setRequestRecordsLimit(limit).build(); - } - - /** Specifies the {@code WatermarkPolicyFactory} as ArrivalTimeWatermarkPolicyFactory. */ - public Read withArrivalTimeWatermarkPolicy() { - return toBuilder() - .setWatermarkPolicyFactory(WatermarkPolicyFactory.withArrivalTimePolicy()) - .build(); - } - - /** - * Specifies the {@code WatermarkPolicyFactory} as ArrivalTimeWatermarkPolicyFactory. - * - *

{@param watermarkIdleDurationThreshold} Denotes the duration for which the watermark can - * be idle. - */ - public Read withArrivalTimeWatermarkPolicy(Duration watermarkIdleDurationThreshold) { - return toBuilder() - .setWatermarkPolicyFactory( - WatermarkPolicyFactory.withArrivalTimePolicy(watermarkIdleDurationThreshold)) - .build(); - } - - /** Specifies the {@code WatermarkPolicyFactory} as ProcessingTimeWatermarkPolicyFactory. */ - public Read withProcessingTimeWatermarkPolicy() { - return toBuilder() - .setWatermarkPolicyFactory(WatermarkPolicyFactory.withProcessingTimePolicy()) - .build(); - } - - /** - * Specifies the {@code WatermarkPolicyFactory} as a custom watermarkPolicyFactory. - * - * @param watermarkPolicyFactory Custom Watermark policy factory. - */ - public Read withCustomWatermarkPolicy(WatermarkPolicyFactory watermarkPolicyFactory) { - checkArgument(watermarkPolicyFactory != null, "watermarkPolicyFactory cannot be null"); - return toBuilder().setWatermarkPolicyFactory(watermarkPolicyFactory).build(); - } - - /** Specifies a fixed delay rate limit policy with the default delay of 1 second. */ - public Read withFixedDelayRateLimitPolicy() { - return toBuilder().setRateLimitPolicyFactory(RateLimitPolicyFactory.withFixedDelay()).build(); - } - - /** - * Specifies a fixed delay rate limit policy with the given delay. - * - * @param delay Denotes the fixed delay duration. - */ - public Read withFixedDelayRateLimitPolicy(Duration delay) { - checkArgument(delay != null, "delay cannot be null"); - return toBuilder() - .setRateLimitPolicyFactory(RateLimitPolicyFactory.withFixedDelay(delay)) - .build(); - } - - /** - * Specifies a dynamic delay rate limit policy with the given function being called at each - * polling interval to get the next delay value. This can be used to change the polling interval - * of a running pipeline based on some external configuration source, for example. 
- * - * @param delay The function to invoke to get the next delay duration. - */ - public Read withDynamicDelayRateLimitPolicy(Supplier delay) { - checkArgument(delay != null, "delay cannot be null"); - return toBuilder().setRateLimitPolicyFactory(RateLimitPolicyFactory.withDelay(delay)).build(); - } - - /** - * Specifies the {@code RateLimitPolicyFactory} for a custom rate limiter. - * - * @param rateLimitPolicyFactory Custom rate limit policy factory. - */ - public Read withCustomRateLimitPolicy(RateLimitPolicyFactory rateLimitPolicyFactory) { - checkArgument(rateLimitPolicyFactory != null, "rateLimitPolicyFactory cannot be null"); - return toBuilder().setRateLimitPolicyFactory(rateLimitPolicyFactory).build(); - } - - /** Specifies the maximum number of messages per one shard. */ - public Read withMaxCapacityPerShard(Integer maxCapacity) { - checkArgument(maxCapacity > 0, "maxCapacity must be positive, but was: %s", maxCapacity); - return toBuilder().setMaxCapacityPerShard(maxCapacity).build(); - } - - @Override - public PCollection expand(PBegin input) { - LOG.warn( - "You are using a deprecated IO for Kinesis. Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - Unbounded unbounded = - org.apache.beam.sdk.io.Read.from( - new KinesisSource( - getAWSClientsProvider(), - getStreamName(), - getInitialPosition(), - getUpToDateThreshold(), - getWatermarkPolicyFactory(), - getRateLimitPolicyFactory(), - getRequestRecordsLimit(), - getMaxCapacityPerShard())); - - PTransform> transform = unbounded; - - if (getMaxNumRecords() < Long.MAX_VALUE || getMaxReadTime() != null) { - transform = - unbounded.withMaxReadTime(getMaxReadTime()).withMaxNumRecords(getMaxNumRecords()); - } - - return input - .apply(transform) - .apply(MapElements.into(new TypeDescriptor() {}).via(getParseFn())) - .setCoder(getCoder()); - } - } - - /** Implementation of {@link #write}. 
*/ - @AutoValue - public abstract static class Write extends PTransform, PDone> { - - abstract @Nullable String getStreamName(); - - abstract @Nullable String getPartitionKey(); - - abstract @Nullable KinesisPartitioner getPartitioner(); - - abstract @Nullable Properties getProducerProperties(); - - abstract @Nullable AWSClientsProvider getAWSClientsProvider(); - - abstract int getRetries(); - - abstract Builder builder(); - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setStreamName(String streamName); - - abstract Builder setPartitionKey(String partitionKey); - - abstract Builder setPartitioner(KinesisPartitioner partitioner); - - abstract Builder setProducerProperties(Properties properties); - - abstract Builder setAWSClientsProvider(AWSClientsProvider clientProvider); - - abstract Builder setRetries(int retries); - - abstract Write build(); - } - - /** Specify Kinesis stream name which will be used for writing, this name is required. */ - public Write withStreamName(String streamName) { - return builder().setStreamName(streamName).build(); - } - - /** - * Specify default partition key. - * - *

In case if you need to define more complicated logic for key partitioning then you can - * create your own implementation of {@link KinesisPartitioner} and specify it by {@link - * KinesisIO.Write#withPartitioner(KinesisPartitioner)} - * - *

Using one of the methods {@link KinesisIO.Write#withPartitioner(KinesisPartitioner)} or - * {@link KinesisIO.Write#withPartitionKey(String)} is required but not both in the same time. - */ - public Write withPartitionKey(String partitionKey) { - return builder().setPartitionKey(partitionKey).build(); - } - - /** - * Allows to specify custom implementation of {@link KinesisPartitioner}. - * - *

This method should be used to balance a distribution of new written records among all - * stream shards. - * - *

Using one of the methods {@link KinesisIO.Write#withPartitioner(KinesisPartitioner)} or - * {@link KinesisIO.Write#withPartitionKey(String)} is required but not both in the same time. - */ - public Write withPartitioner(KinesisPartitioner partitioner) { - return builder().setPartitioner(partitioner).build(); - } - - /** - * Specify the configuration properties for Kinesis Producer Library (KPL). - * - *

Example of creating new KPL configuration: - * - *

{@code Properties properties = new Properties(); - * properties.setProperty("CollectionMaxCount", "1000"); - * properties.setProperty("ConnectTimeout", "10000");} - */ - public Write withProducerProperties(Properties properties) { - return builder().setProducerProperties(properties).build(); - } - - /** - * Allows to specify custom {@link AWSClientsProvider}. {@link AWSClientsProvider} creates new - * {@link IKinesisProducer} which is later used for writing to Kinesis. - * - *

This method should be used if {@link Write#withAWSClientsProvider(AWSCredentialsProvider, - * Regions)} does not suit well. - */ - public Write withAWSClientsProvider(AWSClientsProvider awsClientsProvider) { - return builder().setAWSClientsProvider(awsClientsProvider).build(); - } - - /** - * Specify {@link AWSCredentialsProvider} and region to be used to write to Kinesis. If you need - * more sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AWSClientsProvider)}. - */ - public Write withAWSClientsProvider( - AWSCredentialsProvider awsCredentialsProvider, Regions region) { - return withAWSClientsProvider(awsCredentialsProvider, region, null); - } - - /** - * Specify credential details and region to be used to write to Kinesis. If you need more - * sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AWSClientsProvider)}. - */ - public Write withAWSClientsProvider(String awsAccessKey, String awsSecretKey, Regions region) { - return withAWSClientsProvider(awsAccessKey, awsSecretKey, region, null); - } - - /** - * Specify {@link AWSCredentialsProvider} and region to be used to write to Kinesis. If you need - * more sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - */ - public Write withAWSClientsProvider( - AWSCredentialsProvider awsCredentialsProvider, Regions region, String serviceEndpoint) { - return withAWSClientsProvider(awsCredentialsProvider, region, serviceEndpoint, true); - } - - /** - * Specify credential details and region to be used to write to Kinesis. If you need more - * sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - */ - public Write withAWSClientsProvider( - String awsAccessKey, String awsSecretKey, Regions region, String serviceEndpoint) { - return withAWSClientsProvider(awsAccessKey, awsSecretKey, region, serviceEndpoint, true); - } - - /** - * Specify credential details and region to be used to write to Kinesis. If you need more - * sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - * - *

The {@code verifyCertificate} disables or enables certificate verification. Never set it - * to false in production. - */ - public Write withAWSClientsProvider( - AWSCredentialsProvider awsCredentialsProvider, - Regions region, - String serviceEndpoint, - boolean verifyCertificate) { - return withAWSClientsProvider( - new BasicKinesisProvider( - awsCredentialsProvider, region, serviceEndpoint, verifyCertificate)); - } - - /** - * Specify credential details and region to be used to write to Kinesis. If you need more - * sophisticated credential protocol, then you should look at {@link - * Write#withAWSClientsProvider(AWSClientsProvider)}. - * - *

The {@code serviceEndpoint} sets an alternative service host. This is useful to execute - * the tests with Kinesis service emulator. - * - *

The {@code verifyCertificate} disables or enables certificate verification. Never set it - * to false in production. - */ - public Write withAWSClientsProvider( - String awsAccessKey, - String awsSecretKey, - Regions region, - String serviceEndpoint, - boolean verifyCertificate) { - AWSCredentialsProvider awsCredentialsProvider = - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKey, awsSecretKey)); - return withAWSClientsProvider( - awsCredentialsProvider, region, serviceEndpoint, verifyCertificate); - } - - /** - * Specify the number of retries that will be used to flush the outstanding records in case if - * they were not flushed from the first time. Default number of retries is {@code - * DEFAULT_NUM_RETRIES = 10}. - * - *

This is used for testing. - */ - @VisibleForTesting - Write withRetries(int retries) { - return builder().setRetries(retries).build(); - } - - @Override - public PDone expand(PCollection input) { - LOG.warn( - "You are using a deprecated IO for Kinesis. Please migrate to module " - + "'org.apache.beam:beam-sdks-java-io-amazon-web-services2'."); - - checkArgument(getStreamName() != null, "withStreamName() is required"); - checkArgument( - (getPartitionKey() != null) || (getPartitioner() != null), - "withPartitionKey() or withPartitioner() is required"); - checkArgument( - getPartitionKey() == null || (getPartitioner() == null), - "only one of either withPartitionKey() or withPartitioner() is possible"); - checkArgument(getAWSClientsProvider() != null, "withAWSClientsProvider() is required"); - createProducerConfiguration(); // verify Kinesis producer configuration can be built - - input.apply(ParDo.of(new KinesisWriterFn(this))); - return PDone.in(input.getPipeline()); - } - - private KinesisProducerConfiguration createProducerConfiguration() { - Properties props = getProducerProperties(); - if (props == null) { - props = new Properties(); - } - return KinesisProducerConfiguration.fromProperties(props); - } - - private static class KinesisWriterFn extends DoFn { - private static final int MAX_NUM_FAILURES = 10; - - /** Usage count of static, shared Kinesis producer. */ - private static int producerRefCount = 0; - - /** Static, shared Kinesis producer. */ - private static IKinesisProducer producer; - - private final KinesisIO.Write spec; - - private transient KinesisPartitioner partitioner; - private transient LinkedBlockingDeque failures; - private transient List> putFutures; - - KinesisWriterFn(KinesisIO.Write spec) { - this.spec = spec; - } - - /** - * Initialize statically shared Kinesis producer if required and count usage. - * - *

NOTE: If there is, for whatever reasons, another instance of a {@link KinesisWriterFn} - * with different producer properties or even a different implementation of {@link - * AWSClientsProvider}, these changes will be silently discarded in favor of an existing - * producer instance. - */ - private void setupSharedProducer() { - synchronized (KinesisWriterFn.class) { - if (producer == null) { - producer = - spec.getAWSClientsProvider() - .createKinesisProducer(spec.createProducerConfiguration()); - producerRefCount = 0; - } - producerRefCount++; - } - } - - /** - * Discard statically shared producer if it is not used anymore according to the usage count. - */ - private void teardownSharedProducer() { - IKinesisProducer obsolete = null; - synchronized (KinesisWriterFn.class) { - if (--producerRefCount == 0) { - obsolete = producer; - producer = null; - } - } - if (obsolete != null) { - obsolete.flushSync(); // should be a noop, but just in case - obsolete.destroy(); - } - } - - @Setup - public void setup() { - setupSharedProducer(); - // Use custom partitioner if it exists - if (spec.getPartitioner() != null) { - partitioner = spec.getPartitioner(); - } - } - - @StartBundle - public void startBundle() { - putFutures = Collections.synchronizedList(new ArrayList<>()); - /** Keep only the first {@link MAX_NUM_FAILURES} occurred exceptions */ - failures = new LinkedBlockingDeque<>(MAX_NUM_FAILURES); - } - - /** - * It adds a record asynchronously which then should be delivered by Kinesis producer in - * background (Kinesis producer forks native processes to do this job). - * - *

The records can be batched and then they will be sent in one HTTP request. Amazon KPL - * supports two types of batching - aggregation and collection - and they can be configured by - * producer properties. - * - *

More details can be found here: KPL Key - * Concepts and Configuring - * the KPL - */ - @ProcessElement - public void processElement(ProcessContext c) { - ByteBuffer data = ByteBuffer.wrap(c.element()); - String partitionKey = spec.getPartitionKey(); - String explicitHashKey = null; - - // Use custom partitioner - if (partitioner != null) { - partitionKey = partitioner.getPartitionKey(c.element()); - explicitHashKey = partitioner.getExplicitHashKey(c.element()); - } - - ListenableFuture f = - producer.addUserRecord(spec.getStreamName(), partitionKey, explicitHashKey, data); - putFutures.add(f); - } - - @FinishBundle - public void finishBundle() throws Exception { - flushBundle(); - } - - /** - * Flush outstanding records until the total number of failed records will be less than 0 or - * the number of retries will be exhausted. The retry timeout starts from 1 second and it - * doubles on every iteration. - */ - private void flushBundle() throws InterruptedException, ExecutionException, IOException { - int retries = spec.getRetries(); - int numFailedRecords; - int retryTimeout = 1000; // initial timeout, 1 sec - String message = ""; - - do { - numFailedRecords = 0; - producer.flush(); - - // Wait for puts to finish and check the results - for (Future f : putFutures) { - UserRecordResult result = f.get(); // this does block - if (!result.isSuccessful()) { - numFailedRecords++; - } - } - - // wait until outstanding records will be flushed - Thread.sleep(retryTimeout); - retryTimeout *= 2; // exponential backoff - } while (numFailedRecords > 0 && retries-- > 0); - - if (numFailedRecords > 0) { - for (Future f : putFutures) { - UserRecordResult result = f.get(); - if (!result.isSuccessful()) { - failures.offer( - new KinesisWriteException( - "Put record was not successful.", new UserRecordFailedException(result))); - } - } - - LOG.error( - "After [{}] retries, number of failed records [{}] is still greater than 0", - spec.getRetries(), - numFailedRecords); - } - - 
checkForFailures(message); - } - - /** If any write has asynchronously failed, fail the bundle with a useful error. */ - private void checkForFailures(String message) throws IOException { - if (failures.isEmpty()) { - return; - } - - StringBuilder logEntry = new StringBuilder(); - logEntry.append(message).append(System.lineSeparator()); - - int i = 0; - while (!failures.isEmpty()) { - i++; - KinesisWriteException exc = failures.remove(); - - logEntry.append(System.lineSeparator()).append(exc.getMessage()); - Throwable cause = exc.getCause(); - if (cause != null) { - logEntry.append(": ").append(cause.getMessage()); - - if (cause instanceof UserRecordFailedException) { - List attempts = - ((UserRecordFailedException) cause).getResult().getAttempts(); - for (Attempt attempt : attempts) { - if (attempt.getErrorMessage() != null) { - logEntry.append(System.lineSeparator()).append(attempt.getErrorMessage()); - } - } - } - } - } - - String errorMessage = - String.format( - "Some errors occurred writing to Kinesis. First %d errors: %s", - i, logEntry.toString()); - throw new IOException(errorMessage); - } - - @Teardown - public void teardown() throws Exception { - teardownSharedProducer(); - } - } - } - - /** An exception that puts information about the failed record. */ - static class KinesisWriteException extends IOException { - KinesisWriteException(String message, Throwable cause) { - super(message, cause); - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisPartitioner.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisPartitioner.java deleted file mode 100644 index 9bd46eaef682..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisPartitioner.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.io.Serializable; - -/** Kinesis interface for custom partitioner. */ -public interface KinesisPartitioner extends Serializable { - String getPartitionKey(byte[] value); - - String getExplicitHashKey(byte[] value); -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java deleted file mode 100644 index a4a935eed7b9..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import java.io.IOException; -import java.util.NoSuchElementException; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Reads data from multiple kinesis shards in a single thread. It uses simple round robin algorithm - * when fetching data from shards. 
- */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class KinesisReader extends UnboundedSource.UnboundedReader { - - private static final Logger LOG = LoggerFactory.getLogger(KinesisReader.class); - - private final SimplifiedKinesisClient kinesis; - private final KinesisSource source; - private final CheckpointGenerator initialCheckpointGenerator; - private final WatermarkPolicyFactory watermarkPolicyFactory; - private final RateLimitPolicyFactory rateLimitPolicyFactory; - private final Duration upToDateThreshold; - private final Duration backlogBytesCheckThreshold; - private CustomOptional currentRecord = CustomOptional.absent(); - private long lastBacklogBytes; - private Instant backlogBytesLastCheckTime = new Instant(0L); - private ShardReadersPool shardReadersPool; - private final Integer maxCapacityPerShard; - - KinesisReader( - SimplifiedKinesisClient kinesis, - CheckpointGenerator initialCheckpointGenerator, - KinesisSource source, - WatermarkPolicyFactory watermarkPolicyFactory, - RateLimitPolicyFactory rateLimitPolicyFactory, - Duration upToDateThreshold, - Integer maxCapacityPerShard) { - this( - kinesis, - initialCheckpointGenerator, - source, - watermarkPolicyFactory, - rateLimitPolicyFactory, - upToDateThreshold, - Duration.standardSeconds(30), - maxCapacityPerShard); - } - - KinesisReader( - SimplifiedKinesisClient kinesis, - CheckpointGenerator initialCheckpointGenerator, - KinesisSource source, - WatermarkPolicyFactory watermarkPolicyFactory, - RateLimitPolicyFactory rateLimitPolicyFactory, - Duration upToDateThreshold, - Duration backlogBytesCheckThreshold, - Integer maxCapacityPerShard) { - this.kinesis = checkNotNull(kinesis, "kinesis"); - this.initialCheckpointGenerator = - checkNotNull(initialCheckpointGenerator, "initialCheckpointGenerator"); - this.watermarkPolicyFactory = watermarkPolicyFactory; - this.rateLimitPolicyFactory = rateLimitPolicyFactory; - this.source = source; - 
this.upToDateThreshold = upToDateThreshold; - this.backlogBytesCheckThreshold = backlogBytesCheckThreshold; - this.maxCapacityPerShard = maxCapacityPerShard; - } - - /** Generates initial checkpoint and instantiates iterators for shards. */ - @Override - public boolean start() throws IOException { - LOG.info("Starting reader using {}", initialCheckpointGenerator); - - try { - shardReadersPool = createShardReadersPool(); - shardReadersPool.start(); - } catch (TransientKinesisException e) { - throw new IOException(e); - } - - return advance(); - } - - /** Retrieves next record from internal buffer. */ - @Override - public boolean advance() throws IOException { - currentRecord = shardReadersPool.nextRecord(); - return currentRecord.isPresent(); - } - - @Override - public byte[] getCurrentRecordId() throws NoSuchElementException { - return currentRecord.get().getUniqueId(); - } - - @Override - public KinesisRecord getCurrent() throws NoSuchElementException { - return currentRecord.get(); - } - - /** - * Returns the approximate time that the current record was inserted into the stream. It is not - * guaranteed to be accurate - this could lead to mark some records as "late" even if they were - * not. Beware of this when setting {@link - * org.apache.beam.sdk.values.WindowingStrategy#withAllowedLateness} - */ - @Override - public Instant getCurrentTimestamp() throws NoSuchElementException { - return currentRecord.get().getApproximateArrivalTimestamp(); - } - - @Override - public void close() throws IOException { - shardReadersPool.stop(); - } - - @Override - public Instant getWatermark() { - return shardReadersPool.getWatermark(); - } - - @Override - public UnboundedSource.CheckpointMark getCheckpointMark() { - return shardReadersPool.getCheckpointMark(); - } - - @Override - public UnboundedSource getCurrentSource() { - return source; - } - - /** - * Returns total size of all records that remain in Kinesis stream. 
The size is estimated taking - * into account size of the records that were added to the stream after timestamp of the most - * recent record returned by the reader. If no records have yet been retrieved from the reader - * {@link UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} is returned. When currently processed - * record is not further behind than {@link #upToDateThreshold} then this method returns 0. - * - *

The method can over-estimate size of the records for the split as it reports the backlog - * across all shards. This can lead to unnecessary decisions to scale up the number of workers but - * will never fail to scale up when this is necessary due to backlog size. - * - * @see BEAM-9439 - */ - @Override - public long getSplitBacklogBytes() { - Instant latestRecordTimestamp = shardReadersPool.getLatestRecordTimestamp(); - - if (latestRecordTimestamp.equals(BoundedWindow.TIMESTAMP_MIN_VALUE)) { - LOG.debug("Split backlog bytes for stream {} unknown", source.getStreamName()); - return UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN; - } - - if (latestRecordTimestamp.plus(upToDateThreshold).isAfterNow()) { - LOG.debug( - "Split backlog bytes for stream {} with latest record timestamp {}: 0 (latest record timestamp is up-to-date with threshold of {})", - source.getStreamName(), - latestRecordTimestamp, - upToDateThreshold); - return 0L; - } - - if (backlogBytesLastCheckTime.plus(backlogBytesCheckThreshold).isAfterNow()) { - LOG.debug( - "Split backlog bytes for {} stream with latest record timestamp {}: {} (cached value)", - source.getStreamName(), - latestRecordTimestamp, - lastBacklogBytes); - return lastBacklogBytes; - } - - try { - lastBacklogBytes = kinesis.getBacklogBytes(source.getStreamName(), latestRecordTimestamp); - backlogBytesLastCheckTime = Instant.now(); - } catch (TransientKinesisException e) { - LOG.warn( - "Transient exception occurred during backlog estimation for stream {}.", - source.getStreamName(), - e); - } - LOG.info( - "Split backlog bytes for {} stream with {} latest record timestamp: {}", - source.getStreamName(), - latestRecordTimestamp, - lastBacklogBytes); - return lastBacklogBytes; - } - - ShardReadersPool createShardReadersPool() throws TransientKinesisException { - return new ShardReadersPool( - kinesis, - initialCheckpointGenerator.generate(kinesis), - watermarkPolicyFactory, - rateLimitPolicyFactory, - maxCapacityPerShard); - } -} 
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java deleted file mode 100644 index 4b4bcc3898c7..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.newArrayList; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.partition; - -import java.io.IOException; -import java.io.Serializable; -import java.util.Iterator; -import java.util.List; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; - -/** - * Checkpoint representing a total progress in a set of shards in single stream. The set of shards - * covered by {@link KinesisReaderCheckpoint} may or may not be equal to set of all shards present - * in the stream. This class is immutable. 
- */ -class KinesisReaderCheckpoint - implements Iterable, UnboundedSource.CheckpointMark, Serializable { - - private final List shardCheckpoints; - - public KinesisReaderCheckpoint(Iterable shardCheckpoints) { - this.shardCheckpoints = ImmutableList.copyOf(shardCheckpoints); - } - - /** - * Splits given multi-shard checkpoint into partitions of approximately equal size. - * - * @param desiredNumSplits - upper limit for number of partitions to generate. - * @return list of checkpoints covering consecutive partitions of current checkpoint. - */ - public List splitInto(int desiredNumSplits) { - int partitionSize = divideAndRoundUp(shardCheckpoints.size(), desiredNumSplits); - - List checkpoints = newArrayList(); - for (List shardPartition : partition(shardCheckpoints, partitionSize)) { - checkpoints.add(new KinesisReaderCheckpoint(shardPartition)); - } - return checkpoints; - } - - private int divideAndRoundUp(int nominator, int denominator) { - return (nominator + denominator - 1) / denominator; - } - - String getStreamName() { - Iterator iterator = iterator(); - return iterator.hasNext() ? iterator.next().getStreamName() : "[unknown]"; - } - - @Override - public void finalizeCheckpoint() throws IOException {} - - @Override - public String toString() { - return shardCheckpoints.toString(); - } - - @Override - public Iterator iterator() { - return shardCheckpoints.iterator(); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java deleted file mode 100644 index 381ee0d81064..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.commons.lang3.builder.HashCodeBuilder.reflectionHashCode; - -import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; -import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import org.apache.commons.lang.builder.EqualsBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Instant; - -/** {@link UserRecord} enhanced with utility methods. 
*/ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class KinesisRecord { - - private Instant readTime; - private String streamName; - private String shardId; - private long subSequenceNumber; - private String sequenceNumber; - private Instant approximateArrivalTimestamp; - private ByteBuffer data; - private String partitionKey; - - public KinesisRecord(UserRecord record, String streamName, String shardId) { - this( - record.getData(), - record.getSequenceNumber(), - record.getSubSequenceNumber(), - record.getPartitionKey(), - new Instant(record.getApproximateArrivalTimestamp()), - Instant.now(), - streamName, - shardId); - } - - public KinesisRecord( - ByteBuffer data, - String sequenceNumber, - long subSequenceNumber, - String partitionKey, - Instant approximateArrivalTimestamp, - Instant readTime, - String streamName, - String shardId) { - this.data = data; - this.sequenceNumber = sequenceNumber; - this.subSequenceNumber = subSequenceNumber; - this.partitionKey = partitionKey; - this.approximateArrivalTimestamp = approximateArrivalTimestamp; - this.readTime = readTime; - this.streamName = streamName; - this.shardId = shardId; - } - - public ExtendedSequenceNumber getExtendedSequenceNumber() { - return new ExtendedSequenceNumber(getSequenceNumber(), getSubSequenceNumber()); - } - - /** @return The unique identifier of the record based on its position in the stream. 
*/ - public byte[] getUniqueId() { - return getExtendedSequenceNumber().toString().getBytes(StandardCharsets.UTF_8); - } - - public Instant getReadTime() { - return readTime; - } - - public String getStreamName() { - return streamName; - } - - public String getShardId() { - return shardId; - } - - public byte[] getDataAsBytes() { - return getData().array(); - } - - @Override - public boolean equals(@Nullable Object obj) { - return EqualsBuilder.reflectionEquals(this, obj); - } - - @Override - public int hashCode() { - return reflectionHashCode(this); - } - - public long getSubSequenceNumber() { - return subSequenceNumber; - } - - /** @return The unique identifier of the record within its shard. */ - public String getSequenceNumber() { - return sequenceNumber; - } - - /** @return The approximate time that the record was inserted into the stream. */ - public Instant getApproximateArrivalTimestamp() { - return approximateArrivalTimestamp; - } - - /** @return The data blob. */ - public ByteBuffer getData() { - return data; - } - - public String getPartitionKey() { - return partitionKey; - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java deleted file mode 100644 index efe4d2346797..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.InstantCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarLongCoder; -import org.joda.time.Instant; - -/** A {@link Coder} for {@link KinesisRecord}. */ -class KinesisRecordCoder extends AtomicCoder { - - private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of(); - private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of(); - private static final InstantCoder INSTANT_CODER = InstantCoder.of(); - private static final VarLongCoder VAR_LONG_CODER = VarLongCoder.of(); - - public static KinesisRecordCoder of() { - return new KinesisRecordCoder(); - } - - @Override - public void encode(KinesisRecord value, OutputStream outStream) throws IOException { - BYTE_ARRAY_CODER.encode(value.getData().array(), outStream); - STRING_CODER.encode(value.getSequenceNumber(), outStream); - STRING_CODER.encode(value.getPartitionKey(), outStream); - INSTANT_CODER.encode(value.getApproximateArrivalTimestamp(), outStream); - VAR_LONG_CODER.encode(value.getSubSequenceNumber(), outStream); - INSTANT_CODER.encode(value.getReadTime(), outStream); - STRING_CODER.encode(value.getStreamName(), outStream); - STRING_CODER.encode(value.getShardId(), outStream); - } - - 
@Override - public KinesisRecord decode(InputStream inStream) throws IOException { - ByteBuffer data = ByteBuffer.wrap(BYTE_ARRAY_CODER.decode(inStream)); - String sequenceNumber = STRING_CODER.decode(inStream); - String partitionKey = STRING_CODER.decode(inStream); - Instant approximateArrivalTimestamp = INSTANT_CODER.decode(inStream); - long subSequenceNumber = VAR_LONG_CODER.decode(inStream); - Instant readTimestamp = INSTANT_CODER.decode(inStream); - String streamName = STRING_CODER.decode(inStream); - String shardId = STRING_CODER.decode(inStream); - return new KinesisRecord( - data, - sequenceNumber, - subSequenceNumber, - partitionKey, - approximateArrivalTimestamp, - readTimestamp, - streamName, - shardId); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisShardClosedException.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisShardClosedException.java deleted file mode 100644 index 322b78a418e9..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisShardClosedException.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -/** Internal exception thrown when shard end is encountered during iteration. */ -class KinesisShardClosedException extends Exception { - - KinesisShardClosedException(String message) { - super(message); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java deleted file mode 100644 index e53d71ed0b81..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.newArrayList; - -import java.util.List; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.SerializableCoder; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.sdk.options.PipelineOptions; -import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** Represents source for single stream in Kinesis. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class KinesisSource extends UnboundedSource { - - private static final Logger LOG = LoggerFactory.getLogger(KinesisSource.class); - - private final AWSClientsProvider awsClientsProvider; - private final String streamName; - private final Duration upToDateThreshold; - private final WatermarkPolicyFactory watermarkPolicyFactory; - private final RateLimitPolicyFactory rateLimitPolicyFactory; - private CheckpointGenerator initialCheckpointGenerator; - private final Integer limit; - private final Integer maxCapacityPerShard; - - KinesisSource( - AWSClientsProvider awsClientsProvider, - String streamName, - StartingPoint startingPoint, - Duration upToDateThreshold, - WatermarkPolicyFactory watermarkPolicyFactory, - RateLimitPolicyFactory rateLimitPolicyFactory, - Integer limit, - Integer maxCapacityPerShard) { - this( - awsClientsProvider, - new DynamicCheckpointGenerator(streamName, startingPoint), - streamName, - upToDateThreshold, - watermarkPolicyFactory, - rateLimitPolicyFactory, - limit, - maxCapacityPerShard); - } - - private KinesisSource( - AWSClientsProvider awsClientsProvider, - CheckpointGenerator initialCheckpoint, - String streamName, - Duration upToDateThreshold, - WatermarkPolicyFactory watermarkPolicyFactory, - RateLimitPolicyFactory 
rateLimitPolicyFactory, - Integer limit, - Integer maxCapacityPerShard) { - this.awsClientsProvider = awsClientsProvider; - this.initialCheckpointGenerator = initialCheckpoint; - this.streamName = streamName; - this.upToDateThreshold = upToDateThreshold; - this.watermarkPolicyFactory = watermarkPolicyFactory; - this.rateLimitPolicyFactory = rateLimitPolicyFactory; - this.limit = limit; - this.maxCapacityPerShard = maxCapacityPerShard; - validate(); - } - - /** - * Generate splits for reading from the stream. Basically, it'll try to evenly split set of shards - * in the stream into {@code desiredNumSplits} partitions. Each partition is then a split. - */ - @Override - public List split(int desiredNumSplits, PipelineOptions options) throws Exception { - KinesisReaderCheckpoint checkpoint = - initialCheckpointGenerator.generate( - SimplifiedKinesisClient.from(awsClientsProvider, limit)); - - List sources = newArrayList(); - - for (KinesisReaderCheckpoint partition : checkpoint.splitInto(desiredNumSplits)) { - sources.add( - new KinesisSource( - awsClientsProvider, - new StaticCheckpointGenerator(partition), - streamName, - upToDateThreshold, - watermarkPolicyFactory, - rateLimitPolicyFactory, - limit, - maxCapacityPerShard)); - } - return sources; - } - - /** - * Creates reader based on given {@link KinesisReaderCheckpoint}. If {@link - * KinesisReaderCheckpoint} is not given, then we use {@code initialCheckpointGenerator} to - * generate new checkpoint. 
- */ - @Override - public UnboundedReader createReader( - PipelineOptions options, KinesisReaderCheckpoint checkpointMark) { - - CheckpointGenerator checkpointGenerator = initialCheckpointGenerator; - - if (checkpointMark != null) { - checkpointGenerator = new StaticCheckpointGenerator(checkpointMark); - } - - LOG.info("Creating new reader using {}", checkpointGenerator); - - return new KinesisReader( - SimplifiedKinesisClient.from(awsClientsProvider, limit), - checkpointGenerator, - this, - watermarkPolicyFactory, - rateLimitPolicyFactory, - upToDateThreshold, - maxCapacityPerShard); - } - - @Override - public Coder getCheckpointMarkCoder() { - return SerializableCoder.of(KinesisReaderCheckpoint.class); - } - - @Override - public void validate() { - checkNotNull(awsClientsProvider); - checkNotNull(initialCheckpointGenerator); - checkNotNull(watermarkPolicyFactory); - checkNotNull(rateLimitPolicyFactory); - } - - @Override - public Coder getOutputCoder() { - return KinesisRecordCoder.of(); - } - - String getStreamName() { - return streamName; - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java deleted file mode 100644 index 8ee1e81558f7..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.util.List; - -public interface RateLimitPolicy { - - /** - * Called after Kinesis records are successfully retrieved. - * - * @param records The list of retrieved records. - */ - default void onSuccess(List records) throws InterruptedException {} - - /** - * Called after the Kinesis client is throttled. - * - * @param e The {@code KinesisClientThrottledException} thrown by the client. - */ - default void onThrottle(KinesisClientThrottledException e) throws InterruptedException {} -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java deleted file mode 100644 index 12e013136abc..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.io.IOException; -import java.io.Serializable; -import java.util.List; -import java.util.function.Supplier; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; -import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Implement this interface to create a {@code RateLimitPolicy}. Used to create a rate limiter for - * each shard. The factory will be called from multiple threads, so if it returns a singleton - * instance of RateLimitPolicy then that instance should be thread-safe, otherwise it should return - * separate RateLimitPolicy instances. 
- */ -public interface RateLimitPolicyFactory extends Serializable { - - RateLimitPolicy getRateLimitPolicy(); - - static RateLimitPolicyFactory withoutLimiter() { - return () -> new RateLimitPolicy() {}; - } - - static RateLimitPolicyFactory withDefaultRateLimiter() { - return withDefaultRateLimiter( - Duration.millis(100), Duration.millis(500), Duration.standardSeconds(1)); - } - - static RateLimitPolicyFactory withDefaultRateLimiter( - Duration emptySuccessBaseDelay, Duration throttledBaseDelay, Duration maxDelay) { - return () -> new DefaultRateLimiter(emptySuccessBaseDelay, throttledBaseDelay, maxDelay); - } - - static RateLimitPolicyFactory withFixedDelay() { - return DelayIntervalRateLimiter::new; - } - - static RateLimitPolicyFactory withFixedDelay(Duration delay) { - return () -> new DelayIntervalRateLimiter(() -> delay); - } - - static RateLimitPolicyFactory withDelay(Supplier delay) { - return () -> new DelayIntervalRateLimiter(delay); - } - - class DelayIntervalRateLimiter implements RateLimitPolicy { - - private static final Supplier DEFAULT_DELAY = () -> Duration.standardSeconds(1); - - private final Supplier delay; - - public DelayIntervalRateLimiter() { - this(DEFAULT_DELAY); - } - - public DelayIntervalRateLimiter(Supplier delay) { - this.delay = delay; - } - - @Override - public void onSuccess(List records) throws InterruptedException { - Thread.sleep(delay.get().getMillis()); - } - } - - /** - * Default rate limiter that throttles reading from a shard using an exponential backoff if the - * response is empty or if the consumer is throttled by AWS. 
- */ - class DefaultRateLimiter implements RateLimitPolicy { - private static final Logger LOG = LoggerFactory.getLogger(DefaultRateLimiter.class); - private final Sleeper sleeper; - private final BackOff throttled; - private final BackOff emptySuccess; - - @VisibleForTesting - DefaultRateLimiter(BackOff emptySuccess, BackOff throttled, Sleeper sleeper) { - this.emptySuccess = emptySuccess; - this.throttled = throttled; - this.sleeper = sleeper; - } - - public DefaultRateLimiter(BackOff emptySuccess, BackOff throttled) { - this(emptySuccess, throttled, Sleeper.DEFAULT); - } - - public DefaultRateLimiter( - Duration emptySuccessBaseDelay, Duration throttledBaseDelay, Duration maxDelay) { - this( - FluentBackoff.DEFAULT - .withInitialBackoff(emptySuccessBaseDelay) - .withMaxBackoff(maxDelay) - .backoff(), - FluentBackoff.DEFAULT - .withInitialBackoff(throttledBaseDelay) - .withMaxBackoff(maxDelay) - .backoff()); - } - - @Override - public void onSuccess(List records) throws InterruptedException { - try { - if (records.isEmpty()) { - BackOffUtils.next(sleeper, emptySuccess); - } else { - emptySuccess.reset(); - } - throttled.reset(); - } catch (IOException e) { - LOG.warn("Error applying onSuccess rate limit policy", e); - } - } - - @Override - public void onThrottle(KinesisClientThrottledException e) throws InterruptedException { - try { - BackOffUtils.next(sleeper, throttled); - } catch (IOException ioe) { - LOG.warn("Error applying onThrottle rate limit policy", e); - } - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java deleted file mode 100644 index 2a0456e04052..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.newArrayList; - -import java.util.List; - -/** - * Filters out records, which were already processed and checkpointed. - * - *

We need this step, because we can get iterators from Kinesis only with "sequenceNumber" - * accuracy, not with "subSequenceNumber" accuracy. - */ -class RecordFilter { - - public List apply(List records, ShardCheckpoint checkpoint) { - List filteredRecords = newArrayList(); - for (KinesisRecord record : records) { - if (checkpoint.isBeforeOrAt(record)) { - filteredRecords.add(record); - } - } - return filteredRecords; - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java deleted file mode 100644 index b185a396d1fd..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static com.amazonaws.services.kinesis.model.ShardIteratorType.AFTER_SEQUENCE_NUMBER; -import static com.amazonaws.services.kinesis.model.ShardIteratorType.AT_SEQUENCE_NUMBER; -import static com.amazonaws.services.kinesis.model.ShardIteratorType.AT_TIMESTAMP; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; -import com.amazonaws.services.kinesis.model.Record; -import com.amazonaws.services.kinesis.model.ShardIteratorType; -import java.io.Serializable; -import org.joda.time.Instant; - -/** - * Checkpoint mark for single shard in the stream. Current position in the shard is determined by - * either: - * - *

    - *
  • {@link #shardIteratorType} if it is equal to {@link ShardIteratorType#LATEST} or {@link - * ShardIteratorType#TRIM_HORIZON} - *
  • combination of {@link #sequenceNumber} and {@link #subSequenceNumber} if {@link - * ShardIteratorType#AFTER_SEQUENCE_NUMBER} or {@link ShardIteratorType#AT_SEQUENCE_NUMBER} - *
- * - * This class is immutable. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class ShardCheckpoint implements Serializable { - - private final String streamName; - private final String shardId; - private final String sequenceNumber; - private final ShardIteratorType shardIteratorType; - private final Long subSequenceNumber; - private final Instant timestamp; - - public ShardCheckpoint(String streamName, String shardId, StartingPoint startingPoint) { - this( - streamName, - shardId, - ShardIteratorType.fromValue(startingPoint.getPositionName()), - startingPoint.getTimestamp()); - } - - public ShardCheckpoint( - String streamName, String shardId, ShardIteratorType shardIteratorType, Instant timestamp) { - this(streamName, shardId, shardIteratorType, null, null, timestamp); - } - - public ShardCheckpoint( - String streamName, - String shardId, - ShardIteratorType shardIteratorType, - String sequenceNumber, - Long subSequenceNumber) { - this(streamName, shardId, shardIteratorType, sequenceNumber, subSequenceNumber, null); - } - - private ShardCheckpoint( - String streamName, - String shardId, - ShardIteratorType shardIteratorType, - String sequenceNumber, - Long subSequenceNumber, - Instant timestamp) { - this.shardIteratorType = checkNotNull(shardIteratorType, "shardIteratorType"); - this.streamName = checkNotNull(streamName, "streamName"); - this.shardId = checkNotNull(shardId, "shardId"); - if (shardIteratorType == AT_SEQUENCE_NUMBER || shardIteratorType == AFTER_SEQUENCE_NUMBER) { - checkNotNull( - sequenceNumber, - "You must provide sequence number for AT_SEQUENCE_NUMBER" + " or AFTER_SEQUENCE_NUMBER"); - } else { - checkArgument( - sequenceNumber == null, - "Sequence number must be null for LATEST, TRIM_HORIZON or AT_TIMESTAMP"); - } - if (shardIteratorType == AT_TIMESTAMP) { - checkNotNull(timestamp, "You must provide timestamp for AT_TIMESTAMP"); - } else { - checkArgument( - timestamp == null, "Timestamp 
must be null for an iterator type other than AT_TIMESTAMP"); - } - - this.subSequenceNumber = subSequenceNumber; - this.sequenceNumber = sequenceNumber; - this.timestamp = timestamp; - } - - /** - * Used to compare {@link ShardCheckpoint} object to {@link KinesisRecord}. Depending on the - * underlying shardIteratorType, it will either compare the timestamp or the {@link - * ExtendedSequenceNumber}. - * - * @param other - * @return if current checkpoint mark points before or at given {@link ExtendedSequenceNumber} - */ - public boolean isBeforeOrAt(KinesisRecord other) { - if (shardIteratorType == AT_TIMESTAMP) { - return timestamp.compareTo(other.getApproximateArrivalTimestamp()) <= 0; - } - int result = extendedSequenceNumber().compareTo(other.getExtendedSequenceNumber()); - if (result == 0) { - return shardIteratorType == AT_SEQUENCE_NUMBER; - } - return result < 0; - } - - private ExtendedSequenceNumber extendedSequenceNumber() { - String fullSequenceNumber = sequenceNumber; - if (fullSequenceNumber == null) { - fullSequenceNumber = shardIteratorType.toString(); - } - return new ExtendedSequenceNumber(fullSequenceNumber, subSequenceNumber); - } - - @Override - public String toString() { - return String.format( - "Checkpoint %s for stream %s, shard %s: %s", - shardIteratorType, streamName, shardId, sequenceNumber); - } - - public String getShardIterator(SimplifiedKinesisClient kinesisClient) - throws TransientKinesisException { - if (checkpointIsInTheMiddleOfAUserRecord()) { - return kinesisClient.getShardIterator( - streamName, shardId, AT_SEQUENCE_NUMBER, sequenceNumber, null); - } - return kinesisClient.getShardIterator( - streamName, shardId, shardIteratorType, sequenceNumber, timestamp); - } - - private boolean checkpointIsInTheMiddleOfAUserRecord() { - return shardIteratorType == AFTER_SEQUENCE_NUMBER && subSequenceNumber != null; - } - - /** - * Used to advance checkpoint mark to position after given {@link Record}. 
- * - * @param record - * @return new checkpoint object pointing directly after given {@link Record} - */ - public ShardCheckpoint moveAfter(KinesisRecord record) { - return new ShardCheckpoint( - streamName, - shardId, - AFTER_SEQUENCE_NUMBER, - record.getSequenceNumber(), - record.getSubSequenceNumber()); - } - - public String getStreamName() { - return streamName; - } - - public String getShardId() { - return shardId; - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java deleted file mode 100644 index 703d10d3640e..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import java.util.Collection; -import java.util.Comparator; -import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Internal shard iterators pool. It maintains the thread pool for reading Kinesis shards in - * separate threads. Read records are stored in a blocking queue of limited capacity. 
- */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class ShardReadersPool { - - private static final Logger LOG = LoggerFactory.getLogger(ShardReadersPool.class); - public static final int DEFAULT_CAPACITY_PER_SHARD = 10_000; - private static final int ATTEMPTS_TO_SHUTDOWN = 3; - private static final int QUEUE_OFFER_TIMEOUT_MS = 500; - private static final int QUEUE_POLL_TIMEOUT_MS = 1000; - - /** - * Executor service for running the threads that read records from shards handled by this pool. - * Each thread runs the {@link ShardReadersPool#readLoop(ShardRecordsIterator, RateLimitPolicy)} - * method and handles exactly one shard. - */ - private final ExecutorService executorService; - - /** - * A Bounded buffer for read records. Records are added to this buffer within {@link - * ShardReadersPool#readLoop(ShardRecordsIterator, RateLimitPolicy)} method and removed in {@link - * ShardReadersPool#nextRecord()}. - */ - private BlockingQueue recordsQueue; - - /** - * A reference to an immutable mapping of {@link ShardRecordsIterator} instances to shard ids. - * This map is replaced with a new one when resharding operation on any handled shard occurs. - */ - private final AtomicReference> shardIteratorsMap; - - /** A map for keeping the current number of records stored in a buffer per shard. 
*/ - private final ConcurrentMap numberOfRecordsInAQueueByShard; - - private final SimplifiedKinesisClient kinesis; - private final WatermarkPolicyFactory watermarkPolicyFactory; - private final RateLimitPolicyFactory rateLimitPolicyFactory; - private final KinesisReaderCheckpoint initialCheckpoint; - private final int queueCapacityPerShard; - private final AtomicBoolean poolOpened = new AtomicBoolean(true); - - ShardReadersPool( - SimplifiedKinesisClient kinesis, - KinesisReaderCheckpoint initialCheckpoint, - WatermarkPolicyFactory watermarkPolicyFactory, - RateLimitPolicyFactory rateLimitPolicyFactory, - int queueCapacityPerShard) { - this.kinesis = kinesis; - this.initialCheckpoint = initialCheckpoint; - this.watermarkPolicyFactory = watermarkPolicyFactory; - this.rateLimitPolicyFactory = rateLimitPolicyFactory; - this.queueCapacityPerShard = queueCapacityPerShard; - this.executorService = Executors.newCachedThreadPool(); - this.numberOfRecordsInAQueueByShard = new ConcurrentHashMap<>(); - this.shardIteratorsMap = new AtomicReference<>(); - } - - void start() throws TransientKinesisException { - ImmutableMap.Builder shardsMap = ImmutableMap.builder(); - for (ShardCheckpoint checkpoint : initialCheckpoint) { - shardsMap.put(checkpoint.getShardId(), createShardIterator(kinesis, checkpoint)); - } - shardIteratorsMap.set(shardsMap.build()); - if (!shardIteratorsMap.get().isEmpty()) { - recordsQueue = - new ArrayBlockingQueue<>(queueCapacityPerShard * shardIteratorsMap.get().size()); - String streamName = initialCheckpoint.getStreamName(); - startReadingShards(shardIteratorsMap.get().values(), streamName); - } else { - // There are no shards to handle when restoring from an empty checkpoint. 
Empty checkpoints - // are generated when the last shard handled by this pool was closed - recordsQueue = new ArrayBlockingQueue<>(1); - } - } - - // Note: readLoop() will log any Throwable raised so opt to ignore the future result - @SuppressWarnings("FutureReturnValueIgnored") - void startReadingShards(Iterable shardRecordsIterators, String streamName) { - if (!shardRecordsIterators.iterator().hasNext()) { - LOG.info("Stream {} will not be read, no shard records iterators available", streamName); - return; - } - LOG.info( - "Starting to read {} stream from {} shards", - streamName, - getShardIdsFromRecordsIterators(shardRecordsIterators)); - for (final ShardRecordsIterator recordsIterator : shardRecordsIterators) { - numberOfRecordsInAQueueByShard.put(recordsIterator.getShardId(), new AtomicInteger()); - executorService.submit( - () -> readLoop(recordsIterator, rateLimitPolicyFactory.getRateLimitPolicy())); - } - } - - private void readLoop(ShardRecordsIterator shardRecordsIterator, RateLimitPolicy rateLimiter) { - while (poolOpened.get()) { - try { - try { - List kinesisRecords = shardRecordsIterator.readNextBatch(); - try { - for (KinesisRecord kinesisRecord : kinesisRecords) { - while (true) { - if (!poolOpened.get()) { - return; - } - if (recordsQueue.offer(kinesisRecord, QUEUE_OFFER_TIMEOUT_MS, MILLISECONDS)) { - numberOfRecordsInAQueueByShard.get(kinesisRecord.getShardId()).incrementAndGet(); - break; - } - } - } - } finally { - // One of the paths into this finally block is recordsQueue.put() throwing - // InterruptedException so we should check the thread's interrupted status before - // calling onSuccess(). 
- if (!Thread.currentThread().isInterrupted()) { - rateLimiter.onSuccess(kinesisRecords); - } - } - } catch (KinesisShardClosedException e) { - LOG.info( - "Shard iterator for {} shard is closed, finishing the read loop", - shardRecordsIterator.getShardId(), - e); - // Wait until all records from already closed shard are taken from the buffer and only - // then start reading successive shards. This guarantees that checkpoints will contain - // either parent or child shard and never both. Such approach allows for more - // straightforward checkpoint restoration than in a case when new shards are read - // immediately. - waitUntilAllShardRecordsRead(shardRecordsIterator); - readFromSuccessiveShards(shardRecordsIterator); - break; - } - } catch (KinesisClientThrottledException e) { - try { - rateLimiter.onThrottle(e); - } catch (InterruptedException ex) { - LOG.warn("Thread was interrupted, finishing the read loop", ex); - Thread.currentThread().interrupt(); - break; - } - } catch (TransientKinesisException e) { - LOG.warn("Transient exception occurred.", e); - } catch (InterruptedException e) { - LOG.warn("Thread was interrupted, finishing the read loop", e); - Thread.currentThread().interrupt(); - break; - } catch (Throwable e) { - LOG.error("Unexpected exception occurred", e); - } - } - LOG.info("Kinesis Shard read loop has finished"); - } - - CustomOptional nextRecord() { - try { - KinesisRecord record = recordsQueue.poll(QUEUE_POLL_TIMEOUT_MS, MILLISECONDS); - if (record == null) { - return CustomOptional.absent(); - } - shardIteratorsMap.get().get(record.getShardId()).ackRecord(record); - - // numberOfRecordsInAQueueByShard contains the counter for a given shard until the shard is - // closed and then it's counter reaches 0. 
Thus the access here is safe - numberOfRecordsInAQueueByShard.get(record.getShardId()).decrementAndGet(); - return CustomOptional.of(record); - } catch (InterruptedException e) { - LOG.warn("Interrupted while waiting for KinesisRecord from the buffer"); - return CustomOptional.absent(); - } - } - - void stop() { - LOG.info("Closing shard iterators pool"); - poolOpened.set(false); - executorService.shutdown(); - awaitTermination(); - if (!executorService.isTerminated()) { - LOG.warn( - "Executor service was not completely terminated after {} attempts, trying to forcibly stop it.", - ATTEMPTS_TO_SHUTDOWN); - executorService.shutdownNow(); - awaitTermination(); - } - } - - private void awaitTermination() { - int attemptsLeft = ATTEMPTS_TO_SHUTDOWN; - boolean isTerminated = executorService.isTerminated(); - - while (!isTerminated && attemptsLeft-- > 0) { - try { - isTerminated = executorService.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - LOG.error("Interrupted while waiting for the executor service to shutdown"); - throw new RuntimeException(e); - } - if (!isTerminated && attemptsLeft > 0) { - LOG.warn( - "Executor service is taking long time to shutdown, will retry. 
{} attempts left", - attemptsLeft); - } - } - } - - Instant getWatermark() { - return getMinTimestamp(ShardRecordsIterator::getShardWatermark); - } - - Instant getLatestRecordTimestamp() { - return getMinTimestamp(ShardRecordsIterator::getLatestRecordTimestamp); - } - - private Instant getMinTimestamp(Function timestampExtractor) { - return shardIteratorsMap.get().values().stream() - .map(timestampExtractor) - .min(Comparator.naturalOrder()) - .orElse(BoundedWindow.TIMESTAMP_MAX_VALUE); - } - - KinesisReaderCheckpoint getCheckpointMark() { - ImmutableMap currentShardIterators = shardIteratorsMap.get(); - return new KinesisReaderCheckpoint( - currentShardIterators.values().stream() - .map( - shardRecordsIterator -> { - checkArgument( - shardRecordsIterator != null, "shardRecordsIterator can not be null"); - return shardRecordsIterator.getCheckpoint(); - }) - .collect(Collectors.toList())); - } - - ShardRecordsIterator createShardIterator( - SimplifiedKinesisClient kinesis, ShardCheckpoint checkpoint) - throws TransientKinesisException { - return new ShardRecordsIterator(checkpoint, kinesis, watermarkPolicyFactory); - } - - /** - * Waits until all records read from given shardRecordsIterator are taken from {@link - * #recordsQueue} and acked. Uses {@link #numberOfRecordsInAQueueByShard} map to track the amount - * of remaining events. - */ - private void waitUntilAllShardRecordsRead(ShardRecordsIterator shardRecordsIterator) - throws InterruptedException { - // Given shard is already closed so no more records will be read from it. Thus the counter for - // that shard will be strictly decreasing to 0. - AtomicInteger numberOfShardRecordsInAQueue = - numberOfRecordsInAQueueByShard.get(shardRecordsIterator.getShardId()); - while (!(numberOfShardRecordsInAQueue.get() == 0)) { - Thread.sleep(TimeUnit.SECONDS.toMillis(1)); - } - } - - /** - * Tries to find successors of a given shard and start reading them. Each closed shard can have 0, - * 1 or 2 successors - * - *
    - *
  • 0 successors - when shard was merged with another shard and this one is considered - * adjacent by merge operation - *
  • 1 successor - when shard was merged with another shard and this one is considered a - * parent by merge operation - *
  • 2 successors - when shard was split into two shards - *
- * - *

Once shard successors are established, the transition to reading new shards can begin. - * During this operation, the immutable {@link ShardReadersPool#shardIteratorsMap} is replaced - * with a new one holding references to {@link ShardRecordsIterator} instances for open shards - * only. Potentially there might be more shard iterators closing at the same time so {@link - * ShardReadersPool#shardIteratorsMap} is updated in a loop using CAS pattern to keep all the - * updates. Then, the counter for already closed shard is removed from {@link - * ShardReadersPool#numberOfRecordsInAQueueByShard} map. - * - *

Finally when update is finished, new threads are spawned for reading the successive shards. - * The thread that handled reading from already closed shard can finally complete. - */ - private void readFromSuccessiveShards(final ShardRecordsIterator closedShardIterator) - throws TransientKinesisException { - List successiveShardRecordIterators = - closedShardIterator.findSuccessiveShardRecordIterators(); - - ImmutableMap current; - ImmutableMap updated; - do { - current = shardIteratorsMap.get(); - updated = - createMapWithSuccessiveShards( - current, closedShardIterator, successiveShardRecordIterators); - } while (!shardIteratorsMap.compareAndSet(current, updated)); - numberOfRecordsInAQueueByShard.remove(closedShardIterator.getShardId()); - - logSuccessiveShardsFromRecordsIterators(closedShardIterator, successiveShardRecordIterators); - - String streamName = closedShardIterator.getStreamName(); - startReadingShards(successiveShardRecordIterators, streamName); - } - - private static void logSuccessiveShardsFromRecordsIterators( - final ShardRecordsIterator closedShardIterator, - final Collection shardRecordsIterators) { - if (shardRecordsIterators.isEmpty()) { - LOG.info( - "Shard {} for {} stream is closed. 
Found no successive shards to read from " - + "as it was merged with another shard and this one is considered adjacent by merge operation", - closedShardIterator.getShardId(), - closedShardIterator.getStreamName()); - } else { - LOG.info( - "Shard {} for {} stream is closed, found successive shards to read from: {}", - closedShardIterator.getShardId(), - closedShardIterator.getStreamName(), - getShardIdsFromRecordsIterators(shardRecordsIterators)); - } - } - - private static List getShardIdsFromRecordsIterators( - final Iterable iterators) { - return StreamSupport.stream(iterators.spliterator(), false) - .map(ShardRecordsIterator::getShardId) - .collect(Collectors.toList()); - } - - private ImmutableMap createMapWithSuccessiveShards( - ImmutableMap current, - ShardRecordsIterator closedShardIterator, - List successiveShardRecordIterators) - throws TransientKinesisException { - ImmutableMap.Builder shardsMap = ImmutableMap.builder(); - Iterable allShards = - Iterables.concat(current.values(), successiveShardRecordIterators); - for (ShardRecordsIterator iterator : allShards) { - if (!closedShardIterator.getShardId().equals(iterator.getShardId())) { - shardsMap.put(iterator.getShardId(), iterator); - } - } - return shardsMap.build(); - } - - @VisibleForTesting - BlockingQueue getRecordsQueue() { - return recordsQueue; - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java deleted file mode 100644 index aae179373a2c..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.model.ExpiredIteratorException; -import com.amazonaws.services.kinesis.model.Shard; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicReference; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Iterates over records in a single shard. Records are retrieved in batches via calls to {@link - * ShardRecordsIterator#readNextBatch()}. Client has to confirm processed records by calling {@link - * ShardRecordsIterator#ackRecord(KinesisRecord)} method. 
- */ -class ShardRecordsIterator { - - private static final Logger LOG = LoggerFactory.getLogger(ShardRecordsIterator.class); - - private final SimplifiedKinesisClient kinesis; - private final RecordFilter filter; - private final String streamName; - private final String shardId; - private final AtomicReference checkpoint; - private final WatermarkPolicy watermarkPolicy; - private final WatermarkPolicyFactory watermarkPolicyFactory; - private final WatermarkPolicy latestRecordTimestampPolicy = - WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy(); - private String shardIterator; - - ShardRecordsIterator( - ShardCheckpoint initialCheckpoint, - SimplifiedKinesisClient simplifiedKinesisClient, - WatermarkPolicyFactory watermarkPolicyFactory) - throws TransientKinesisException { - this(initialCheckpoint, simplifiedKinesisClient, watermarkPolicyFactory, new RecordFilter()); - } - - ShardRecordsIterator( - ShardCheckpoint initialCheckpoint, - SimplifiedKinesisClient simplifiedKinesisClient, - WatermarkPolicyFactory watermarkPolicyFactory, - RecordFilter filter) - throws TransientKinesisException { - this.checkpoint = new AtomicReference<>(checkNotNull(initialCheckpoint, "initialCheckpoint")); - this.filter = checkNotNull(filter, "filter"); - this.kinesis = checkNotNull(simplifiedKinesisClient, "simplifiedKinesisClient"); - this.streamName = initialCheckpoint.getStreamName(); - this.shardId = initialCheckpoint.getShardId(); - this.shardIterator = initialCheckpoint.getShardIterator(kinesis); - this.watermarkPolicy = watermarkPolicyFactory.createWatermarkPolicy(); - this.watermarkPolicyFactory = watermarkPolicyFactory; - } - - List readNextBatch() - throws TransientKinesisException, KinesisShardClosedException { - if (shardIterator == null) { - throw new KinesisShardClosedException( - String.format( - "Shard iterator reached end of the shard: streamName=%s, shardId=%s", - streamName, shardId)); - } - GetKinesisRecordsResult response = fetchRecords(); - 
LOG.debug( - "Fetched {} new records from shard: streamName={}, shardId={}", - response.getRecords().size(), - streamName, - shardId); - - List filteredRecords = filter.apply(response.getRecords(), checkpoint.get()); - return filteredRecords; - } - - private GetKinesisRecordsResult fetchRecords() throws TransientKinesisException { - try { - GetKinesisRecordsResult response = kinesis.getRecords(shardIterator, streamName, shardId); - shardIterator = response.getNextShardIterator(); - return response; - } catch (ExpiredIteratorException e) { - LOG.info( - "Refreshing expired iterator for shard: streamName={}, shardId={}", - streamName, - shardId, - e); - shardIterator = checkpoint.get().getShardIterator(kinesis); - return fetchRecords(); - } - } - - ShardCheckpoint getCheckpoint() { - return checkpoint.get(); - } - - void ackRecord(KinesisRecord record) { - checkpoint.set(checkpoint.get().moveAfter(record)); - watermarkPolicy.update(record); - latestRecordTimestampPolicy.update(record); - } - - Instant getShardWatermark() { - return watermarkPolicy.getWatermark(); - } - - Instant getLatestRecordTimestamp() { - return latestRecordTimestampPolicy.getWatermark(); - } - - String getShardId() { - return shardId; - } - - String getStreamName() { - return streamName; - } - - List findSuccessiveShardRecordIterators() throws TransientKinesisException { - List shards = kinesis.listShardsFollowingClosedShard(streamName, shardId); - List successiveShardRecordIterators = new ArrayList<>(); - for (Shard shard : shards) { - if (shardId.equals(shard.getParentShardId())) { - ShardCheckpoint shardCheckpoint = - new ShardCheckpoint( - streamName, - shard.getShardId(), - new StartingPoint(InitialPositionInStream.TRIM_HORIZON)); - successiveShardRecordIterators.add( - new ShardRecordsIterator(shardCheckpoint, kinesis, watermarkPolicyFactory)); - } - } - return successiveShardRecordIterators; - } -} diff --git 
a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java deleted file mode 100644 index 88fcc7fcec35..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.AmazonClientException; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.model.Datapoint; -import com.amazonaws.services.cloudwatch.model.Dimension; -import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsRequest; -import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsResult; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; -import com.amazonaws.services.kinesis.model.DescribeStreamSummaryRequest; -import com.amazonaws.services.kinesis.model.ExpiredIteratorException; -import com.amazonaws.services.kinesis.model.GetRecordsRequest; -import com.amazonaws.services.kinesis.model.GetRecordsResult; -import com.amazonaws.services.kinesis.model.GetShardIteratorRequest; -import com.amazonaws.services.kinesis.model.LimitExceededException; -import com.amazonaws.services.kinesis.model.ListShardsRequest; -import com.amazonaws.services.kinesis.model.ListShardsResult; -import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException; -import com.amazonaws.services.kinesis.model.Shard; -import com.amazonaws.services.kinesis.model.ShardFilter; -import com.amazonaws.services.kinesis.model.ShardFilterType; -import com.amazonaws.services.kinesis.model.ShardIteratorType; -import com.amazonaws.services.kinesis.model.StreamDescriptionSummary; -import java.io.IOException; -import java.util.Collections; -import java.util.Date; -import java.util.List; -import java.util.concurrent.Callable; -import java.util.function.Supplier; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; 
-import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.joda.time.Minutes; - -/** Wraps {@link AmazonKinesis} class providing much simpler interface and proper error handling. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class SimplifiedKinesisClient { - - private static final String KINESIS_NAMESPACE = "AWS/Kinesis"; - private static final String INCOMING_RECORDS_METRIC = "IncomingBytes"; - private static final int PERIOD_GRANULARITY_IN_SECONDS = 60; - private static final String SUM_STATISTIC = "Sum"; - private static final String STREAM_NAME_DIMENSION = "StreamName"; - private static final int LIST_SHARDS_MAX_RESULTS = 1_000; - private static final Duration - SPACING_FOR_TIMESTAMP_LIST_SHARDS_REQUEST_TO_NOT_EXCEED_TRIM_HORIZON = - Duration.standardMinutes(5); - private static final int DESCRIBE_STREAM_SUMMARY_MAX_ATTEMPTS = 10; - private static final Duration DESCRIBE_STREAM_SUMMARY_INITIAL_BACKOFF = - Duration.standardSeconds(1); - - private final AmazonKinesis kinesis; - private final AmazonCloudWatch cloudWatch; - private final Integer limit; - private final Supplier currentInstantSupplier; - - public SimplifiedKinesisClient( - AmazonKinesis kinesis, AmazonCloudWatch cloudWatch, Integer limit) { - this(kinesis, cloudWatch, limit, Instant::now); - } - - SimplifiedKinesisClient( - AmazonKinesis kinesis, - AmazonCloudWatch cloudWatch, - Integer limit, - Supplier currentInstantSupplier) { - this.kinesis = checkNotNull(kinesis, "kinesis"); - this.cloudWatch = checkNotNull(cloudWatch, "cloudWatch"); - this.limit = limit; - this.currentInstantSupplier = currentInstantSupplier; - } - - public static SimplifiedKinesisClient from(AWSClientsProvider provider, Integer limit) { - return new SimplifiedKinesisClient( - 
provider.getKinesisClient(), provider.getCloudWatchClient(), limit); - } - - public String getShardIterator( - final String streamName, - final String shardId, - final ShardIteratorType shardIteratorType, - final String startingSequenceNumber, - final Instant timestamp) - throws TransientKinesisException { - final Date date = timestamp != null ? timestamp.toDate() : null; - return wrapExceptions( - () -> - kinesis - .getShardIterator( - new GetShardIteratorRequest() - .withStreamName(streamName) - .withShardId(shardId) - .withShardIteratorType(shardIteratorType) - .withStartingSequenceNumber(startingSequenceNumber) - .withTimestamp(date)) - .getShardIterator()); - } - - public List listShardsAtPoint(final String streamName, final StartingPoint startingPoint) - throws TransientKinesisException { - ShardFilter shardFilter = - wrapExceptions(() -> buildShardFilterForStartingPoint(streamName, startingPoint)); - return listShards(streamName, shardFilter); - } - - private ShardFilter buildShardFilterForStartingPoint( - String streamName, StartingPoint startingPoint) throws IOException, InterruptedException { - InitialPositionInStream position = startingPoint.getPosition(); - switch (position) { - case LATEST: - return new ShardFilter().withType(ShardFilterType.AT_LATEST); - case TRIM_HORIZON: - return new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - case AT_TIMESTAMP: - return buildShardFilterForTimestamp(streamName, startingPoint.getTimestamp()); - default: - throw new IllegalArgumentException( - String.format("Unrecognized '%s' position to create shard filter with", position)); - } - } - - private ShardFilter buildShardFilterForTimestamp( - String streamName, Instant startingPointTimestamp) throws IOException, InterruptedException { - StreamDescriptionSummary streamDescription = describeStreamSummary(streamName); - - Instant streamCreationTimestamp = new Instant(streamDescription.getStreamCreationTimestamp()); - if 
(streamCreationTimestamp.isAfter(startingPointTimestamp)) { - return new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - } - - Duration retentionPeriod = Duration.standardHours(streamDescription.getRetentionPeriodHours()); - - Instant streamTrimHorizonTimestamp = - currentInstantSupplier - .get() - .minus(retentionPeriod) - .plus(SPACING_FOR_TIMESTAMP_LIST_SHARDS_REQUEST_TO_NOT_EXCEED_TRIM_HORIZON); - if (startingPointTimestamp.isAfter(streamTrimHorizonTimestamp)) { - return new ShardFilter() - .withType(ShardFilterType.AT_TIMESTAMP) - .withTimestamp(startingPointTimestamp.toDate()); - } else { - return new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - } - } - - private StreamDescriptionSummary describeStreamSummary(final String streamName) - throws IOException, InterruptedException { - // DescribeStreamSummary has limits that can be hit fairly easily if we are attempting - // to configure multiple KinesisIO inputs in the same account. Retry up to - // DESCRIBE_STREAM_SUMMARY_MAX_ATTEMPTS times if we end up hitting that limit. - // - // Only pass the wrapped exception up once that limit is reached. Use FluentBackoff - // to implement the retry policy. 
- FluentBackoff retryBackoff = - FluentBackoff.DEFAULT - .withMaxRetries(DESCRIBE_STREAM_SUMMARY_MAX_ATTEMPTS) - .withInitialBackoff(DESCRIBE_STREAM_SUMMARY_INITIAL_BACKOFF); - BackOff backoff = retryBackoff.backoff(); - Sleeper sleeper = Sleeper.DEFAULT; - - DescribeStreamSummaryRequest request = new DescribeStreamSummaryRequest(); - request.setStreamName(streamName); - while (true) { - try { - return kinesis.describeStreamSummary(request).getStreamDescriptionSummary(); - } catch (LimitExceededException exc) { - if (!BackOffUtils.next(sleeper, backoff)) { - throw exc; - } - } - } - } - - public List listShardsFollowingClosedShard( - final String streamName, final String exclusiveStartShardId) - throws TransientKinesisException { - ShardFilter shardFilter = - new ShardFilter() - .withType(ShardFilterType.AFTER_SHARD_ID) - .withShardId(exclusiveStartShardId); - return listShards(streamName, shardFilter); - } - - private List listShards(final String streamName, final ShardFilter shardFilter) - throws TransientKinesisException { - return wrapExceptions( - () -> { - ImmutableList.Builder shardsBuilder = ImmutableList.builder(); - - String currentNextToken = null; - do { - ListShardsRequest request = new ListShardsRequest(); - request.setMaxResults(LIST_SHARDS_MAX_RESULTS); - if (currentNextToken != null) { - request.setNextToken(currentNextToken); - } else { - request.setStreamName(streamName); - } - request.setShardFilter(shardFilter); - - ListShardsResult response = kinesis.listShards(request); - List shards = response.getShards(); - shardsBuilder.addAll(shards); - currentNextToken = response.getNextToken(); - } while (currentNextToken != null); - - return shardsBuilder.build(); - }); - } - - /** - * Gets records from Kinesis and deaggregates them if needed. 
- * - * @return list of deaggregated records - * @throws TransientKinesisException - in case of recoverable situation - */ - public GetKinesisRecordsResult getRecords(String shardIterator, String streamName, String shardId) - throws TransientKinesisException { - return getRecords(shardIterator, streamName, shardId, limit); - } - - /** - * Gets records from Kinesis and deaggregates them if needed. - * - * @return list of deaggregated records - * @throws TransientKinesisException - in case of recoverable situation - */ - public GetKinesisRecordsResult getRecords( - final String shardIterator, - final String streamName, - final String shardId, - final Integer limit) - throws TransientKinesisException { - return wrapExceptions( - () -> { - GetRecordsResult response = - kinesis.getRecords( - new GetRecordsRequest().withShardIterator(shardIterator).withLimit(limit)); - return new GetKinesisRecordsResult( - UserRecord.deaggregate(response.getRecords()), - response.getNextShardIterator(), - response.getMillisBehindLatest(), - streamName, - shardId); - }); - } - - /** - * Gets total size in bytes of all events that remain in Kinesis stream after specified instant. - * - * @return total size in bytes of all Kinesis events after specified instant - */ - public long getBacklogBytes(String streamName, Instant countSince) - throws TransientKinesisException { - return getBacklogBytes(streamName, countSince, new Instant()); - } - - /** - * Gets total size in bytes of all events that remain in Kinesis stream between specified - * instants. 
- * - * @return total size in bytes of all Kinesis events after specified instant - */ - public long getBacklogBytes( - final String streamName, final Instant countSince, final Instant countTo) - throws TransientKinesisException { - return wrapExceptions( - () -> { - Minutes period = Minutes.minutesBetween(countSince, countTo); - if (period.isLessThan(Minutes.ONE)) { - return 0L; - } - - GetMetricStatisticsRequest request = - createMetricStatisticsRequest(streamName, countSince, countTo, period); - - long totalSizeInBytes = 0; - GetMetricStatisticsResult result = cloudWatch.getMetricStatistics(request); - for (Datapoint point : result.getDatapoints()) { - totalSizeInBytes += point.getSum().longValue(); - } - return totalSizeInBytes; - }); - } - - GetMetricStatisticsRequest createMetricStatisticsRequest( - String streamName, Instant countSince, Instant countTo, Minutes period) { - return new GetMetricStatisticsRequest() - .withNamespace(KINESIS_NAMESPACE) - .withMetricName(INCOMING_RECORDS_METRIC) - .withPeriod(period.getMinutes() * PERIOD_GRANULARITY_IN_SECONDS) - .withStartTime(countSince.toDate()) - .withEndTime(countTo.toDate()) - .withStatistics(Collections.singletonList(SUM_STATISTIC)) - .withDimensions( - Collections.singletonList( - new Dimension().withName(STREAM_NAME_DIMENSION).withValue(streamName))); - } - - /** - * Wraps Amazon specific exceptions into more friendly format. - * - * @throws TransientKinesisException - in case of recoverable situation, i.e. the request rate is - * too high, Kinesis remote service failed, network issue, etc. 
- * @throws ExpiredIteratorException - if iterator needs to be refreshed - * @throws RuntimeException - in all other cases - */ - private T wrapExceptions(Callable callable) throws TransientKinesisException { - try { - return callable.call(); - } catch (ExpiredIteratorException e) { - throw e; - } catch (LimitExceededException | ProvisionedThroughputExceededException e) { - throw new KinesisClientThrottledException( - "Too many requests to Kinesis. Wait some time and retry.", e); - } catch (AmazonServiceException e) { - if (e.getErrorType() == AmazonServiceException.ErrorType.Service) { - throw new TransientKinesisException("Kinesis backend failed. Wait some time and retry.", e); - } - throw new RuntimeException("Kinesis client side failure", e); - } catch (AmazonClientException e) { - if (e.isRetryable()) { - throw new TransientKinesisException("Retryable client failure", e); - } - throw new RuntimeException("Not retryable client failure", e); - } catch (Exception e) { - throw new RuntimeException("Unknown kinesis failure, when trying to reach kinesis", e); - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java deleted file mode 100644 index 6fde16d7f3b9..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import java.io.Serializable; -import java.util.Objects; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Instant; - -/** - * Denotes a point at which the reader should start reading from a Kinesis stream. It can be - * expressed either as an {@link InitialPositionInStream} enum constant or a timestamp, in which - * case the reader will start reading at the specified point in time. 
- */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class StartingPoint implements Serializable { - - private final InitialPositionInStream position; - private final Instant timestamp; - - public StartingPoint(InitialPositionInStream position) { - this.position = checkNotNull(position, "position"); - this.timestamp = null; - } - - public StartingPoint(Instant timestamp) { - this.timestamp = checkNotNull(timestamp, "timestamp"); - this.position = InitialPositionInStream.AT_TIMESTAMP; - } - - public InitialPositionInStream getPosition() { - return position; - } - - public String getPositionName() { - return position.name(); - } - - public Instant getTimestamp() { - return timestamp; - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - StartingPoint that = (StartingPoint) o; - return position == that.position && Objects.equals(timestamp, that.timestamp); - } - - @Override - public int hashCode() { - return Objects.hash(position, timestamp); - } - - @Override - public String toString() { - if (timestamp == null) { - return position.toString(); - } else { - return "Starting at timestamp " + timestamp; - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java deleted file mode 100644 index 9364f98eccea..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -/** Always returns the same instance of checkpoint. */ -class StaticCheckpointGenerator implements CheckpointGenerator { - - private final KinesisReaderCheckpoint checkpoint; - - public StaticCheckpointGenerator(KinesisReaderCheckpoint checkpoint) { - checkNotNull(checkpoint, "checkpoint"); - this.checkpoint = checkpoint; - } - - @Override - public KinesisReaderCheckpoint generate(SimplifiedKinesisClient client) { - return checkpoint; - } - - @Override - public String toString() { - return checkpoint.toString(); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java deleted file mode 100644 index 876acf85c998..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.AmazonClientException; - -/** A transient exception thrown by Kinesis. */ -class TransientKinesisException extends Exception { - - public TransientKinesisException(String s, AmazonClientException e) { - super(s, e); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkParameters.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkParameters.java deleted file mode 100644 index f604dc9dc11b..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkParameters.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.google.auto.value.AutoValue; -import java.io.Serializable; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.joda.time.Duration; -import org.joda.time.Instant; - -/** {@code WatermarkParameters} contains the parameters used for watermark computation. */ -@AutoValue -public abstract class WatermarkParameters implements Serializable { - - private static final SerializableFunction ARRIVAL_TIME_FN = - KinesisRecord::getApproximateArrivalTimestamp; - private static final Duration STANDARD_WATERMARK_IDLE_DURATION_THRESHOLD = - Duration.standardMinutes(2); - - abstract Instant getCurrentWatermark(); - - abstract Instant getEventTime(); - - abstract Instant getLastUpdateTime(); - - abstract SerializableFunction getTimestampFn(); - - abstract Duration getWatermarkIdleDurationThreshold(); - - public abstract Builder toBuilder(); - - public static Builder builder() { - return new AutoValue_WatermarkParameters.Builder() - .setCurrentWatermark(BoundedWindow.TIMESTAMP_MIN_VALUE) - .setEventTime(BoundedWindow.TIMESTAMP_MIN_VALUE) - .setTimestampFn(ARRIVAL_TIME_FN) - .setLastUpdateTime(Instant.now()) - .setWatermarkIdleDurationThreshold(STANDARD_WATERMARK_IDLE_DURATION_THRESHOLD); - } - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setCurrentWatermark(Instant currentWatermark); - - abstract Builder setEventTime(Instant eventTime); - - abstract Builder setLastUpdateTime(Instant now); - - abstract Builder setWatermarkIdleDurationThreshold(Duration watermarkIdleDurationThreshold); - - abstract Builder setTimestampFn(SerializableFunction timestampFn); - - abstract WatermarkParameters build(); - } - - 
public static WatermarkParameters create() { - return builder().build(); - } - - /** - * Specify the {@code SerializableFunction} to extract the event time from a {@code - * KinesisRecord}. The default event timestamp is the arrival timestamp of the record. - * - * @param timestampFn Serializable function to extract the timestamp from a record. - */ - public WatermarkParameters withTimestampFn( - SerializableFunction timestampFn) { - checkArgument(timestampFn != null, "timestampFn function is null"); - return toBuilder().setTimestampFn(timestampFn).build(); - } - - /** - * Specify the watermark idle duration to consider before advancing the watermark. The default - * watermark idle duration threshold is 2 minutes. - */ - public WatermarkParameters withWatermarkIdleDurationThreshold(Duration idleDurationThreshold) { - checkArgument(idleDurationThreshold != null, "watermark idle duration threshold is null"); - return toBuilder().setWatermarkIdleDurationThreshold(idleDurationThreshold).build(); - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicy.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicy.java deleted file mode 100644 index 69ac45f0a7dc..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicy.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.io.Serializable; -import org.joda.time.Instant; - -/** Implement this interface to define a custom watermark calculation heuristic. */ -public interface WatermarkPolicy extends Serializable { - - Instant getWatermark(); - - void update(KinesisRecord record); -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicyFactory.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicyFactory.java deleted file mode 100644 index 62de2fe16a5e..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicyFactory.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import java.io.Serializable; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Ordering; -import org.joda.time.Duration; -import org.joda.time.Instant; - -/** - * Implement this interface to create a {@code WatermarkPolicy}. Used by the {@code - * ShardRecordsIterator} to create a watermark policy for every shard. - */ -public interface WatermarkPolicyFactory extends Serializable { - - WatermarkPolicy createWatermarkPolicy(); - - /** Returns an ArrivalTimeWatermarkPolicy. */ - static WatermarkPolicyFactory withArrivalTimePolicy() { - return ArrivalTimeWatermarkPolicy::new; - } - - /** - * Returns an ArrivalTimeWatermarkPolicy. - * - * @param watermarkIdleDurationThreshold watermark idle duration threshold. - */ - static WatermarkPolicyFactory withArrivalTimePolicy(Duration watermarkIdleDurationThreshold) { - return () -> new ArrivalTimeWatermarkPolicy(watermarkIdleDurationThreshold); - } - - /** Returns an ProcessingTimeWatermarkPolicy. */ - static WatermarkPolicyFactory withProcessingTimePolicy() { - return ProcessingTimeWatermarkPolicy::new; - } - - /** - * Returns an custom WatermarkPolicyFactory. - * - * @param watermarkParameters Watermark parameters (timestamp extractor, watermark lag) for the - * policy. - */ - static WatermarkPolicyFactory withCustomWatermarkPolicy(WatermarkParameters watermarkParameters) { - return () -> new CustomWatermarkPolicy(watermarkParameters); - } - - /** - * ArrivalTimeWatermarkPolicy uses {@link CustomWatermarkPolicy} for watermark computation. It - * uses the arrival time of the record as the event time for watermark calculations. 
- */ - class ArrivalTimeWatermarkPolicy implements WatermarkPolicy { - private final CustomWatermarkPolicy watermarkPolicy; - - ArrivalTimeWatermarkPolicy() { - this.watermarkPolicy = - new CustomWatermarkPolicy( - WatermarkParameters.create() - .withTimestampFn(KinesisRecord::getApproximateArrivalTimestamp)); - } - - ArrivalTimeWatermarkPolicy(Duration idleDurationThreshold) { - WatermarkParameters watermarkParameters = - WatermarkParameters.create() - .withTimestampFn(KinesisRecord::getApproximateArrivalTimestamp) - .withWatermarkIdleDurationThreshold(idleDurationThreshold); - this.watermarkPolicy = new CustomWatermarkPolicy(watermarkParameters); - } - - @Override - public Instant getWatermark() { - return watermarkPolicy.getWatermark(); - } - - @Override - public void update(KinesisRecord record) { - watermarkPolicy.update(record); - } - } - - /** - * CustomWatermarkPolicy uses parameters defined in {@link WatermarkParameters} to compute - * watermarks. This can be used as a standard heuristic to compute watermarks. Used by {@link - * ArrivalTimeWatermarkPolicy}. - */ - class CustomWatermarkPolicy implements WatermarkPolicy { - private WatermarkParameters watermarkParameters; - - CustomWatermarkPolicy(WatermarkParameters watermarkParameters) { - this.watermarkParameters = watermarkParameters; - } - - @Override - public Instant getWatermark() { - Instant now = Instant.now(); - Instant watermarkIdleThreshold = - now.minus(watermarkParameters.getWatermarkIdleDurationThreshold()); - - Instant newWatermark = - watermarkParameters.getLastUpdateTime().isBefore(watermarkIdleThreshold) - ? 
watermarkIdleThreshold - : watermarkParameters.getEventTime(); - - if (newWatermark.isAfter(watermarkParameters.getCurrentWatermark())) { - watermarkParameters = - watermarkParameters.toBuilder().setCurrentWatermark(newWatermark).build(); - } - return watermarkParameters.getCurrentWatermark(); - } - - @Override - public void update(KinesisRecord record) { - watermarkParameters = - watermarkParameters - .toBuilder() - .setEventTime( - Ordering.natural() - .max( - watermarkParameters.getEventTime(), - watermarkParameters.getTimestampFn().apply(record))) - .setLastUpdateTime(Instant.now()) - .build(); - } - } - - /** Watermark policy where the processing time is used as the event time. */ - class ProcessingTimeWatermarkPolicy implements WatermarkPolicy { - @Override - public Instant getWatermark() { - return Instant.now(); - } - - @Override - public void update(KinesisRecord record) { - // do nothing - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/AwsModule.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/AwsModule.java deleted file mode 100644 index d8396d5da924..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/AwsModule.java +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis.serde; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.BasicSessionCredentials; -import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper; -import com.amazonaws.auth.EnvironmentVariableCredentialsProvider; -import com.amazonaws.auth.PropertiesFileCredentialsProvider; -import com.amazonaws.auth.SystemPropertiesCredentialsProvider; -import com.amazonaws.auth.profile.ProfileCredentialsProvider; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.Module; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; -import com.fasterxml.jackson.databind.jsontype.TypeSerializer; -import com.fasterxml.jackson.databind.module.SimpleModule; -import java.io.IOException; -import java.util.Map; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.commons.lang3.reflect.FieldUtils; - -/** - * A 
Jackson {@link Module} that registers a {@link JsonSerializer} and {@link JsonDeserializer} for - * {@link AWSCredentialsProvider} and some subclasses. The serialized form is a JSON map. - * - *

Note: This module is a stripped down version of {@link AwsModule} in 'amazon-web-services' - * excluding support for STS. - */ -class AwsModule extends SimpleModule { - - private static final String AWS_ACCESS_KEY_ID = "awsAccessKeyId"; - private static final String AWS_SECRET_KEY = "awsSecretKey"; - private static final String SESSION_TOKEN = "sessionToken"; - private static final String CREDENTIALS_FILE_PATH = "credentialsFilePath"; - - @SuppressWarnings({"nullness"}) - AwsModule() { - super("AwsModule"); - setMixInAnnotation(AWSCredentialsProvider.class, AWSCredentialsProviderMixin.class); - } - - /** A mixin to add Jackson annotations to {@link AWSCredentialsProvider}. */ - @JsonDeserialize(using = AWSCredentialsProviderDeserializer.class) - @JsonSerialize(using = AWSCredentialsProviderSerializer.class) - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY) - private static class AWSCredentialsProviderMixin {} - - private static class AWSCredentialsProviderDeserializer - extends JsonDeserializer { - - @Override - public AWSCredentialsProvider deserialize(JsonParser jsonParser, DeserializationContext context) - throws IOException { - return context.readValue(jsonParser, AWSCredentialsProvider.class); - } - - @Override - public AWSCredentialsProvider deserializeWithType( - JsonParser jsonParser, DeserializationContext context, TypeDeserializer typeDeserializer) - throws IOException { - Map asMap = - checkNotNull(jsonParser.readValueAs(new TypeReference>() {})); - - String typeNameKey = typeDeserializer.getPropertyName(); - String typeName = getNotNull(asMap, typeNameKey, "unknown"); - if (hasName(AWSStaticCredentialsProvider.class, typeName)) { - boolean isSession = asMap.containsKey(SESSION_TOKEN); - if (isSession) { - return new AWSStaticCredentialsProvider( - new BasicSessionCredentials( - getNotNull(asMap, AWS_ACCESS_KEY_ID, typeName), - getNotNull(asMap, AWS_SECRET_KEY, typeName), - getNotNull(asMap, SESSION_TOKEN, typeName))); - } 
else { - return new AWSStaticCredentialsProvider( - new BasicAWSCredentials( - getNotNull(asMap, AWS_ACCESS_KEY_ID, typeName), - getNotNull(asMap, AWS_SECRET_KEY, typeName))); - } - } else if (hasName(PropertiesFileCredentialsProvider.class, typeName)) { - return new PropertiesFileCredentialsProvider( - getNotNull(asMap, CREDENTIALS_FILE_PATH, typeName)); - } else if (hasName(ClasspathPropertiesFileCredentialsProvider.class, typeName)) { - return new ClasspathPropertiesFileCredentialsProvider( - getNotNull(asMap, CREDENTIALS_FILE_PATH, typeName)); - } else if (hasName(DefaultAWSCredentialsProviderChain.class, typeName)) { - return DefaultAWSCredentialsProviderChain.getInstance(); - } else if (hasName(EnvironmentVariableCredentialsProvider.class, typeName)) { - return new EnvironmentVariableCredentialsProvider(); - } else if (hasName(SystemPropertiesCredentialsProvider.class, typeName)) { - return new SystemPropertiesCredentialsProvider(); - } else if (hasName(ProfileCredentialsProvider.class, typeName)) { - return new ProfileCredentialsProvider(); - } else if (hasName(EC2ContainerCredentialsProviderWrapper.class, typeName)) { - return new EC2ContainerCredentialsProviderWrapper(); - } else { - throw new IOException( - String.format("AWS credential provider type '%s' is not supported", typeName)); - } - } - - @SuppressWarnings({"nullness"}) - private String getNotNull(Map map, String key, String typeName) { - return checkNotNull( - map.get(key), "AWS credentials provider type '%s' is missing '%s'", typeName, key); - } - - private boolean hasName(Class clazz, String typeName) { - return typeName.equals(clazz.getSimpleName()); - } - } - - private static class AWSCredentialsProviderSerializer - extends JsonSerializer { - // These providers are singletons, so don't require any serialization, other than type. 
- private static final ImmutableSet SINGLETON_CREDENTIAL_PROVIDERS = - ImmutableSet.of( - DefaultAWSCredentialsProviderChain.class, - EnvironmentVariableCredentialsProvider.class, - SystemPropertiesCredentialsProvider.class, - ProfileCredentialsProvider.class, - EC2ContainerCredentialsProviderWrapper.class); - - @Override - public void serialize( - AWSCredentialsProvider credentialsProvider, - JsonGenerator jsonGenerator, - SerializerProvider serializers) - throws IOException { - serializers.defaultSerializeValue(credentialsProvider, jsonGenerator); - } - - @Override - public void serializeWithType( - AWSCredentialsProvider credentialsProvider, - JsonGenerator jsonGenerator, - SerializerProvider serializers, - TypeSerializer typeSerializer) - throws IOException { - // BEAM-11958 Use deprecated Jackson APIs to be compatible with older versions of jackson - typeSerializer.writeTypePrefixForObject(credentialsProvider, jsonGenerator); - - Class providerClass = credentialsProvider.getClass(); - if (providerClass.equals(AWSStaticCredentialsProvider.class)) { - AWSCredentials credentials = credentialsProvider.getCredentials(); - if (credentials.getClass().equals(BasicSessionCredentials.class)) { - BasicSessionCredentials sessionCredentials = (BasicSessionCredentials) credentials; - jsonGenerator.writeStringField(AWS_ACCESS_KEY_ID, sessionCredentials.getAWSAccessKeyId()); - jsonGenerator.writeStringField(AWS_SECRET_KEY, sessionCredentials.getAWSSecretKey()); - jsonGenerator.writeStringField(SESSION_TOKEN, sessionCredentials.getSessionToken()); - } else { - jsonGenerator.writeStringField(AWS_ACCESS_KEY_ID, credentials.getAWSAccessKeyId()); - jsonGenerator.writeStringField(AWS_SECRET_KEY, credentials.getAWSSecretKey()); - } - } else if (providerClass.equals(PropertiesFileCredentialsProvider.class)) { - jsonGenerator.writeStringField( - CREDENTIALS_FILE_PATH, readProviderField(credentialsProvider, CREDENTIALS_FILE_PATH)); - } else if 
(providerClass.equals(ClasspathPropertiesFileCredentialsProvider.class)) { - jsonGenerator.writeStringField( - CREDENTIALS_FILE_PATH, readProviderField(credentialsProvider, CREDENTIALS_FILE_PATH)); - } else if (!SINGLETON_CREDENTIAL_PROVIDERS.contains(credentialsProvider.getClass())) { - throw new IllegalArgumentException( - "Unsupported AWS credentials provider type " + credentialsProvider.getClass()); - } - // BEAM-11958 Use deprecated Jackson APIs to be compatible with older versions of jackson - typeSerializer.writeTypeSuffixForObject(credentialsProvider, jsonGenerator); - } - - private String readProviderField(AWSCredentialsProvider provider, String fieldName) - throws IOException { - try { - return (String) checkNotNull(FieldUtils.readField(provider, fieldName, true)); - } catch (NullPointerException | IllegalArgumentException | IllegalAccessException e) { - throw new IOException( - String.format( - "Failed to access private field '%s' of AWS credential provider type '%s' with reflection", - fieldName, provider.getClass().getSimpleName()), - e); - } - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/AwsSerializableUtils.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/AwsSerializableUtils.java deleted file mode 100644 index 37f7b4d65b46..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/AwsSerializableUtils.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis.serde; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.IOException; - -/** Utilities for working with AWS Serializables. */ -public class AwsSerializableUtils { - - public static String serialize(AWSCredentialsProvider awsCredentialsProvider) { - ObjectMapper om = new ObjectMapper(); - om.registerModule(new AwsModule()); - try { - return om.writeValueAsString(awsCredentialsProvider); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("AwsCredentialsProvider can not be serialized to Json", e); - } - } - - public static AWSCredentialsProvider deserialize(String awsCredentialsProviderSerialized) { - ObjectMapper om = new ObjectMapper(); - om.registerModule(new AwsModule()); - try { - return om.readValue(awsCredentialsProviderSerialized, AWSCredentialsProvider.class); - } catch (IOException e) { - throw new IllegalArgumentException( - "AwsCredentialsProvider can not be deserialized from Json", e); - } - } -} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/package-info.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/package-info.java deleted file mode 100644 index 4384814b0818..000000000000 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/serde/package-info.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or 
more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** Defines serializers / deserializers for AWS. */ -package org.apache.beam.sdk.io.kinesis.serde; diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java deleted file mode 100644 index 704a5ab07ba9..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java +++ /dev/null @@ -1,504 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static java.lang.Integer.parseInt; -import static java.lang.Math.min; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.transform; -import static org.apache.commons.lang.builder.HashCodeBuilder.reflectionHashCode; - -import com.amazonaws.AmazonWebServiceRequest; -import com.amazonaws.ResponseMetadata; -import com.amazonaws.http.HttpResponse; -import com.amazonaws.http.SdkHttpMetadata; -import com.amazonaws.regions.Region; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.model.AddTagsToStreamRequest; -import com.amazonaws.services.kinesis.model.AddTagsToStreamResult; -import com.amazonaws.services.kinesis.model.CreateStreamRequest; -import com.amazonaws.services.kinesis.model.CreateStreamResult; -import com.amazonaws.services.kinesis.model.DecreaseStreamRetentionPeriodRequest; -import com.amazonaws.services.kinesis.model.DecreaseStreamRetentionPeriodResult; -import com.amazonaws.services.kinesis.model.DeleteStreamRequest; -import com.amazonaws.services.kinesis.model.DeleteStreamResult; -import com.amazonaws.services.kinesis.model.DeregisterStreamConsumerRequest; -import com.amazonaws.services.kinesis.model.DeregisterStreamConsumerResult; -import com.amazonaws.services.kinesis.model.DescribeLimitsRequest; -import com.amazonaws.services.kinesis.model.DescribeLimitsResult; -import com.amazonaws.services.kinesis.model.DescribeStreamConsumerRequest; -import com.amazonaws.services.kinesis.model.DescribeStreamConsumerResult; -import com.amazonaws.services.kinesis.model.DescribeStreamRequest; -import com.amazonaws.services.kinesis.model.DescribeStreamResult; -import com.amazonaws.services.kinesis.model.DescribeStreamSummaryRequest; -import 
com.amazonaws.services.kinesis.model.DescribeStreamSummaryResult; -import com.amazonaws.services.kinesis.model.DisableEnhancedMonitoringRequest; -import com.amazonaws.services.kinesis.model.DisableEnhancedMonitoringResult; -import com.amazonaws.services.kinesis.model.EnableEnhancedMonitoringRequest; -import com.amazonaws.services.kinesis.model.EnableEnhancedMonitoringResult; -import com.amazonaws.services.kinesis.model.GetRecordsRequest; -import com.amazonaws.services.kinesis.model.GetRecordsResult; -import com.amazonaws.services.kinesis.model.GetShardIteratorRequest; -import com.amazonaws.services.kinesis.model.GetShardIteratorResult; -import com.amazonaws.services.kinesis.model.IncreaseStreamRetentionPeriodRequest; -import com.amazonaws.services.kinesis.model.IncreaseStreamRetentionPeriodResult; -import com.amazonaws.services.kinesis.model.LimitExceededException; -import com.amazonaws.services.kinesis.model.ListShardsRequest; -import com.amazonaws.services.kinesis.model.ListShardsResult; -import com.amazonaws.services.kinesis.model.ListStreamConsumersRequest; -import com.amazonaws.services.kinesis.model.ListStreamConsumersResult; -import com.amazonaws.services.kinesis.model.ListStreamsRequest; -import com.amazonaws.services.kinesis.model.ListStreamsResult; -import com.amazonaws.services.kinesis.model.ListTagsForStreamRequest; -import com.amazonaws.services.kinesis.model.ListTagsForStreamResult; -import com.amazonaws.services.kinesis.model.MergeShardsRequest; -import com.amazonaws.services.kinesis.model.MergeShardsResult; -import com.amazonaws.services.kinesis.model.PutRecordRequest; -import com.amazonaws.services.kinesis.model.PutRecordResult; -import com.amazonaws.services.kinesis.model.PutRecordsRequest; -import com.amazonaws.services.kinesis.model.PutRecordsResult; -import com.amazonaws.services.kinesis.model.Record; -import com.amazonaws.services.kinesis.model.RegisterStreamConsumerRequest; -import 
com.amazonaws.services.kinesis.model.RegisterStreamConsumerResult; -import com.amazonaws.services.kinesis.model.RemoveTagsFromStreamRequest; -import com.amazonaws.services.kinesis.model.RemoveTagsFromStreamResult; -import com.amazonaws.services.kinesis.model.Shard; -import com.amazonaws.services.kinesis.model.ShardIteratorType; -import com.amazonaws.services.kinesis.model.SplitShardRequest; -import com.amazonaws.services.kinesis.model.SplitShardResult; -import com.amazonaws.services.kinesis.model.StartStreamEncryptionRequest; -import com.amazonaws.services.kinesis.model.StartStreamEncryptionResult; -import com.amazonaws.services.kinesis.model.StopStreamEncryptionRequest; -import com.amazonaws.services.kinesis.model.StopStreamEncryptionResult; -import com.amazonaws.services.kinesis.model.UpdateShardCountRequest; -import com.amazonaws.services.kinesis.model.UpdateShardCountResult; -import com.amazonaws.services.kinesis.model.UpdateStreamModeRequest; -import com.amazonaws.services.kinesis.model.UpdateStreamModeResult; -import com.amazonaws.services.kinesis.producer.IKinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; -import com.amazonaws.services.kinesis.waiters.AmazonKinesisWaiters; -import java.io.Serializable; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter; -import org.apache.commons.lang.builder.EqualsBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Instant; -import org.mockito.Mockito; - -/** Mock implementation of {@link AmazonKinesis} for testing. 
*/ -class AmazonKinesisMock implements AmazonKinesis { - - static class TestData implements Serializable { - - private final String data; - private final Instant arrivalTimestamp; - private final String sequenceNumber; - - public TestData(KinesisRecord record) { - this( - new String(record.getData().array(), StandardCharsets.UTF_8), - record.getApproximateArrivalTimestamp(), - record.getSequenceNumber()); - } - - public TestData(String data, Instant arrivalTimestamp, String sequenceNumber) { - this.data = data; - this.arrivalTimestamp = arrivalTimestamp; - this.sequenceNumber = sequenceNumber; - } - - public Record convertToRecord() { - return new Record() - .withApproximateArrivalTimestamp(arrivalTimestamp.toDate()) - .withData(ByteBuffer.wrap(data.getBytes(StandardCharsets.UTF_8))) - .withSequenceNumber(sequenceNumber) - .withPartitionKey(""); - } - - @Override - public boolean equals(@Nullable Object obj) { - return EqualsBuilder.reflectionEquals(this, obj); - } - - @Override - public int hashCode() { - return reflectionHashCode(this); - } - - @Override - public String toString() { - return "TestData{" - + "data='" - + data - + '\'' - + ", arrivalTimestamp=" - + arrivalTimestamp - + ", sequenceNumber='" - + sequenceNumber - + '\'' - + '}'; - } - } - - static class Provider implements AWSClientsProvider { - - private final List> shardedData; - private final int numberOfRecordsPerGet; - - private boolean expectedListShardsLimitExceededException; - - public Provider(List> shardedData, int numberOfRecordsPerGet) { - this.shardedData = shardedData; - this.numberOfRecordsPerGet = numberOfRecordsPerGet; - } - - /** Simulate limit exceeded exception for ListShards. 
*/ - public Provider withExpectedListShardsLimitExceededException() { - expectedListShardsLimitExceededException = true; - return this; - } - - @Override - public AmazonKinesis getKinesisClient() { - AmazonKinesisMock client = - new AmazonKinesisMock( - shardedData.stream() - .map(testData -> transform(testData, TestData::convertToRecord)) - .collect(Collectors.toList()), - numberOfRecordsPerGet); - if (expectedListShardsLimitExceededException) { - client = client.withExpectedListShardsLimitExceededException(); - } - return client; - } - - @Override - public AmazonCloudWatch getCloudWatchClient() { - return Mockito.mock(AmazonCloudWatch.class); - } - - @Override - public IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config) { - throw new RuntimeException("Not implemented"); - } - } - - private final List> shardedData; - private final int numberOfRecordsPerGet; - - private boolean expectedListShardsLimitExceededException; - - public AmazonKinesisMock(List> shardedData, int numberOfRecordsPerGet) { - this.shardedData = shardedData; - this.numberOfRecordsPerGet = numberOfRecordsPerGet; - } - - public AmazonKinesisMock withExpectedListShardsLimitExceededException() { - this.expectedListShardsLimitExceededException = true; - return this; - } - - @Override - public GetRecordsResult getRecords(GetRecordsRequest getRecordsRequest) { - List shardIteratorParts = - Splitter.on(':').splitToList(getRecordsRequest.getShardIterator()); - int shardId = parseInt(shardIteratorParts.get(0)); - int startingRecord = parseInt(shardIteratorParts.get(1)); - List shardData = shardedData.get(shardId); - - int toIndex = min(startingRecord + numberOfRecordsPerGet, shardData.size()); - int fromIndex = min(startingRecord, toIndex); - return new GetRecordsResult() - .withRecords(shardData.subList(fromIndex, toIndex)) - .withNextShardIterator(String.format("%s:%s", shardId, toIndex)) - .withMillisBehindLatest(0L); - } - - @Override - public GetShardIteratorResult 
getShardIterator(GetShardIteratorRequest getShardIteratorRequest) { - ShardIteratorType shardIteratorType = - ShardIteratorType.fromValue(getShardIteratorRequest.getShardIteratorType()); - - String shardIterator; - if (shardIteratorType == ShardIteratorType.TRIM_HORIZON) { - shardIterator = String.format("%s:%s", getShardIteratorRequest.getShardId(), 0); - } else { - throw new RuntimeException("Not implemented"); - } - - return new GetShardIteratorResult().withShardIterator(shardIterator); - } - - @Override - public DescribeStreamResult describeStream(String streamName, String exclusiveStartShardId) { - throw new RuntimeException("Not implemented"); - } - - @Override - public void setEndpoint(String endpoint) {} - - @Override - public void setRegion(Region region) {} - - @Override - public AddTagsToStreamResult addTagsToStream(AddTagsToStreamRequest addTagsToStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public CreateStreamResult createStream(CreateStreamRequest createStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public CreateStreamResult createStream(String streamName, Integer shardCount) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DecreaseStreamRetentionPeriodResult decreaseStreamRetentionPeriod( - DecreaseStreamRetentionPeriodRequest decreaseStreamRetentionPeriodRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DeleteStreamResult deleteStream(DeleteStreamRequest deleteStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DeleteStreamResult deleteStream(String streamName) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DeregisterStreamConsumerResult deregisterStreamConsumer( - DeregisterStreamConsumerRequest deregisterStreamConsumerRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeLimitsResult 
describeLimits(DescribeLimitsRequest describeLimitsRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamResult describeStream(DescribeStreamRequest describeStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamResult describeStream(String streamName) { - return describeStream(streamName, null); - } - - @Override - public DescribeStreamResult describeStream( - String streamName, Integer limit, String exclusiveStartShardId) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamConsumerResult describeStreamConsumer( - DescribeStreamConsumerRequest describeStreamConsumerRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamSummaryResult describeStreamSummary( - DescribeStreamSummaryRequest describeStreamSummaryRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DisableEnhancedMonitoringResult disableEnhancedMonitoring( - DisableEnhancedMonitoringRequest disableEnhancedMonitoringRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public EnableEnhancedMonitoringResult enableEnhancedMonitoring( - EnableEnhancedMonitoringRequest enableEnhancedMonitoringRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public GetShardIteratorResult getShardIterator( - String streamName, String shardId, String shardIteratorType) { - throw new RuntimeException("Not implemented"); - } - - @Override - public GetShardIteratorResult getShardIterator( - String streamName, String shardId, String shardIteratorType, String startingSequenceNumber) { - throw new RuntimeException("Not implemented"); - } - - @Override - public IncreaseStreamRetentionPeriodResult increaseStreamRetentionPeriod( - IncreaseStreamRetentionPeriodRequest increaseStreamRetentionPeriodRequest) { - throw new RuntimeException("Not implemented"); - } - - 
@Override - public ListShardsResult listShards(ListShardsRequest listShardsRequest) { - if (expectedListShardsLimitExceededException) { - throw new LimitExceededException("ListShards rate limit exceeded"); - } - - ListShardsResult result = new ListShardsResult(); - - List shards = - IntStream.range(0, shardedData.size()) - .boxed() - .map(i -> new Shard().withShardId(Integer.toString(i))) - .collect(Collectors.toList()); - result.setShards(shards); - - HttpResponse response = new HttpResponse(null, null); - response.setStatusCode(200); - result.setSdkHttpMetadata(SdkHttpMetadata.from(response)); - return result; - } - - @Override - public ListStreamConsumersResult listStreamConsumers( - ListStreamConsumersRequest listStreamConsumersRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public ListStreamsResult listStreams(ListStreamsRequest listStreamsRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public ListStreamsResult listStreams() { - throw new RuntimeException("Not implemented"); - } - - @Override - public ListStreamsResult listStreams(String exclusiveStartStreamName) { - throw new RuntimeException("Not implemented"); - } - - @Override - public ListStreamsResult listStreams(Integer limit, String exclusiveStartStreamName) { - throw new RuntimeException("Not implemented"); - } - - @Override - public ListTagsForStreamResult listTagsForStream( - ListTagsForStreamRequest listTagsForStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public MergeShardsResult mergeShards(MergeShardsRequest mergeShardsRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public MergeShardsResult mergeShards( - String streamName, String shardToMerge, String adjacentShardToMerge) { - throw new RuntimeException("Not implemented"); - } - - @Override - public PutRecordResult putRecord(PutRecordRequest putRecordRequest) { - throw new RuntimeException("Not implemented"); - } 
- - @Override - public PutRecordResult putRecord(String streamName, ByteBuffer data, String partitionKey) { - throw new RuntimeException("Not implemented"); - } - - @Override - public PutRecordResult putRecord( - String streamName, ByteBuffer data, String partitionKey, String sequenceNumberForOrdering) { - throw new RuntimeException("Not implemented"); - } - - @Override - public PutRecordsResult putRecords(PutRecordsRequest putRecordsRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public RegisterStreamConsumerResult registerStreamConsumer( - RegisterStreamConsumerRequest registerStreamConsumerRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public RemoveTagsFromStreamResult removeTagsFromStream( - RemoveTagsFromStreamRequest removeTagsFromStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public SplitShardResult splitShard(SplitShardRequest splitShardRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public SplitShardResult splitShard( - String streamName, String shardToSplit, String newStartingHashKey) { - throw new RuntimeException("Not implemented"); - } - - @Override - public StartStreamEncryptionResult startStreamEncryption( - StartStreamEncryptionRequest startStreamEncryptionRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public StopStreamEncryptionResult stopStreamEncryption( - StopStreamEncryptionRequest stopStreamEncryptionRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public UpdateShardCountResult updateShardCount(UpdateShardCountRequest updateShardCountRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public UpdateStreamModeResult updateStreamMode(UpdateStreamModeRequest updateStreamModeRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public void shutdown() {} - - @Override - public ResponseMetadata 
getCachedResponseMetadata(AmazonWebServiceRequest request) { - throw new RuntimeException("Not implemented"); - } - - @Override - public AmazonKinesisWaiters waiters() { - throw new RuntimeException("Not implemented"); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/BasicKinesisClientProviderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/BasicKinesisClientProviderTest.java deleted file mode 100644 index 938dc9b6f8b6..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/BasicKinesisClientProviderTest.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.junit.Assert.assertEquals; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.regions.Regions; -import org.apache.beam.sdk.util.SerializableUtils; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests on {@link org.apache.beam.sdk.io.aws2.kinesis.BasicKinesisProvider}. 
*/ -@RunWith(JUnit4.class) -public class BasicKinesisClientProviderTest { - private static final String ACCESS_KEY_ID = "ACCESS_KEY_ID"; - private static final String SECRET_ACCESS_KEY = "SECRET_ACCESS_KEY"; - - @Test - public void testSerialization() { - AWSCredentialsProvider awsCredentialsProvider = - new AWSStaticCredentialsProvider(new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)); - - BasicKinesisProvider kinesisProvider = - new BasicKinesisProvider(awsCredentialsProvider, Regions.AP_EAST_1, null, true); - - byte[] serializedBytes = SerializableUtils.serializeToByteArray(kinesisProvider); - - BasicKinesisProvider kinesisProviderDeserialized = - (BasicKinesisProvider) - SerializableUtils.deserializeFromByteArray(serializedBytes, "Basic Kinesis Provider"); - - assertEquals(kinesisProvider, kinesisProviderDeserialized); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java deleted file mode 100644 index 00e6b9334025..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import com.google.common.testing.EqualsTester; -import java.util.NoSuchElementException; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link CustomOptional}. */ -@RunWith(JUnit4.class) -public class CustomOptionalTest { - - @Test(expected = NoSuchElementException.class) - public void absentThrowsNoSuchElementExceptionOnGet() { - CustomOptional.absent().get(); - } - - @Test - public void testEqualsAndHashCode() { - new EqualsTester() - .addEqualityGroup(CustomOptional.absent(), CustomOptional.absent()) - .addEqualityGroup(CustomOptional.of(3), CustomOptional.of(3)) - .addEqualityGroup(CustomOptional.of(11)) - .addEqualityGroup(CustomOptional.of("3")) - .testEquals(); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java deleted file mode 100644 index 1426f3b52197..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.model.Shard; -import java.util.List; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -/** * */ -@RunWith(MockitoJUnitRunner.class) -public class DynamicCheckpointGeneratorTest { - - @Mock private SimplifiedKinesisClient kinesisClient; - @Mock private Shard shard1, shard2, shard3; - - @Test - public void shouldMapAllShardsToCheckpoints() throws Exception { - when(shard1.getShardId()).thenReturn("shard-01"); - when(shard2.getShardId()).thenReturn("shard-02"); - when(shard3.getShardId()).thenReturn("shard-03"); - List shards = ImmutableList.of(shard1, shard2, shard3); - String streamName = "stream"; - StartingPoint startingPoint = new StartingPoint(InitialPositionInStream.LATEST); - when(kinesisClient.listShardsAtPoint(streamName, startingPoint)).thenReturn(shards); - DynamicCheckpointGenerator underTest = - new DynamicCheckpointGenerator(streamName, startingPoint); - - KinesisReaderCheckpoint checkpoint = underTest.generate(kinesisClient); - - assertThat(checkpoint).hasSize(3); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOIT.java 
b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOIT.java deleted file mode 100644 index b2ec825f7d85..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOIT.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import java.io.Serializable; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import org.apache.beam.sdk.io.GenerateSequence; -import org.apache.beam.sdk.io.common.HashingFn; -import org.apache.beam.sdk.io.common.TestRow; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; -import org.junit.function.ThrowingRunnable; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.slf4j.LoggerFactory; -import org.testcontainers.containers.localstack.LocalStackContainer; -import org.testcontainers.containers.localstack.LocalStackContainer.Service; -import org.testcontainers.utility.DockerImageName; - -/** - * Integration test, that writes and reads data to and from real Kinesis. You need to provide {@link - * KinesisTestOptions} in order to run this if you want to test it with production setup. 
By default - * when no options are provided an instance of localstack is used. - */ -@RunWith(JUnit4.class) -public class KinesisIOIT implements Serializable { - private static final String LOCALSTACK_VERSION = "0.12.18"; - - @Rule public TestPipeline pipelineWrite = TestPipeline.create(); - @Rule public TestPipeline pipelineRead = TestPipeline.create(); - - // Will be run in reverse order - private static final List teardownTasks = new ArrayList<>(); - - private static KinesisTestOptions options; - - private static Instant now = Instant.now(); - - @BeforeClass - public static void setup() throws Exception { - PipelineOptionsFactory.register(KinesisTestOptions.class); - options = TestPipeline.testingPipelineOptions().as(KinesisTestOptions.class); - - if (options.getUseLocalstack()) { - setupLocalstack(); - } - if (options.getCreateStream()) { - AmazonKinesis kinesisClient = createKinesisClient(); - teardownTasks.add(kinesisClient::shutdown); - - createStream(kinesisClient); - teardownTasks.add(() -> deleteStream(kinesisClient)); - } - } - - @AfterClass - public static void teardown() { - Lists.reverse(teardownTasks).forEach(KinesisIOIT::safeRun); - teardownTasks.clear(); - } - - /** Test which write and then read data for a Kinesis stream. */ - @Test - public void testWriteThenRead() { - runWrite(); - runRead(); - } - - /** Write test dataset into Kinesis stream. 
*/ - private void runWrite() { - pipelineWrite - .apply("Generate Sequence", GenerateSequence.from(0).to(options.getNumberOfRecords())) - .apply("Prepare TestRows", ParDo.of(new TestRow.DeterministicallyConstructTestRowFn())) - .apply("Prepare Kinesis input records", ParDo.of(new ConvertToBytes())) - .apply( - "Write to Kinesis", - KinesisIO.write() - .withStreamName(options.getAwsKinesisStream()) - .withPartitioner(new RandomPartitioner()) - .withAWSClientsProvider( - options.getAwsAccessKey(), - options.getAwsSecretKey(), - Regions.fromName(options.getAwsKinesisRegion()), - options.getAwsServiceEndpoint(), - options.getAwsVerifyCertificate())); - - pipelineWrite.run().waitUntilFinish(); - } - - /** Read test dataset from Kinesis stream. */ - private void runRead() { - PCollection output = - pipelineRead.apply( - KinesisIO.read() - .withStreamName(options.getAwsKinesisStream()) - .withAWSClientsProvider( - options.getAwsAccessKey(), - options.getAwsSecretKey(), - Regions.fromName(options.getAwsKinesisRegion()), - options.getAwsServiceEndpoint(), - options.getAwsVerifyCertificate()) - .withMaxNumRecords(options.getNumberOfRecords()) - // to prevent endless running in case of error - .withMaxReadTime(Duration.standardMinutes(10L)) - .withInitialPositionInStream(InitialPositionInStream.AT_TIMESTAMP) - .withInitialTimestampInStream(now) - .withRequestRecordsLimit(1000)); - - PAssert.thatSingleton(output.apply("Count All", Count.globally())) - .isEqualTo((long) options.getNumberOfRecords()); - - PCollection consolidatedHashcode = - output - .apply(ParDo.of(new ExtractDataValues())) - .apply("Hash row contents", Combine.globally(new HashingFn()).withoutDefaults()); - - PAssert.that(consolidatedHashcode) - .containsInAnyOrder(TestRow.getExpectedHashForRowCount(options.getNumberOfRecords())); - - pipelineRead.run().waitUntilFinish(); - } - - /** Necessary setup for localstack environment. 
*/ - private static void setupLocalstack() { - // For some unclear reason localstack requires a timestamp in seconds - now = Instant.ofEpochMilli(Long.divideUnsigned(now.getMillis(), 1000L)); - - LocalStackContainer kinesisContainer = - new LocalStackContainer( - DockerImageName.parse("localstack/localstack").withTag(LOCALSTACK_VERSION)) - .withServices(Service.KINESIS) - .withEnv("USE_SSL", "true") - .withStartupAttempts(3); - - kinesisContainer.start(); - teardownTasks.add(() -> kinesisContainer.stop()); - - options.setAwsServiceEndpoint(kinesisContainer.getEndpointOverride(Service.KINESIS).toString()); - options.setAwsKinesisRegion(kinesisContainer.getRegion()); - options.setAwsAccessKey(kinesisContainer.getAccessKey()); - options.setAwsSecretKey(kinesisContainer.getSecretKey()); - options.setAwsVerifyCertificate(false); - options.setCreateStream(true); - } - - private static AmazonKinesis createKinesisClient() { - AWSCredentials credentials = - new BasicAWSCredentials(options.getAwsAccessKey(), options.getAwsSecretKey()); - AmazonKinesisClientBuilder clientBuilder = - AmazonKinesisClientBuilder.standard() - .withCredentials(new AWSStaticCredentialsProvider(credentials)); - - if (options.getAwsServiceEndpoint() != null) { - clientBuilder.setEndpointConfiguration( - new AwsClientBuilder.EndpointConfiguration( - options.getAwsServiceEndpoint(), options.getAwsKinesisRegion())); - } else { - clientBuilder.setRegion(options.getAwsKinesisRegion()); - } - - return clientBuilder.build(); - } - - private static void createStream(AmazonKinesis kinesisClient) throws Exception { - kinesisClient.createStream(options.getAwsKinesisStream(), options.getNumberOfShards()); - int attempts = 10; - for (int i = 0; i <= attempts; ++i) { - String streamStatus = - kinesisClient - .describeStream(options.getAwsKinesisStream()) - .getStreamDescription() - .getStreamStatus(); - if ("ACTIVE".equals(streamStatus)) { - return; - } - Thread.sleep(1000L); - } - throw new 
RuntimeException("Unable to initialize stream"); - } - - private static void deleteStream(AmazonKinesis kinesisClient) { - kinesisClient.deleteStream(options.getAwsKinesisStream()); - } - - private static void safeRun(ThrowingRunnable task) { - try { - task.run(); - } catch (Throwable e) { - LoggerFactory.getLogger(KinesisIOIT.class).warn("Cleanup task failed", e); - } - } - - /** Produces test rows. */ - private static class ConvertToBytes extends DoFn { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(String.valueOf(c.element().name()).getBytes(StandardCharsets.UTF_8)); - } - } - - /** Read rows from Table. */ - private static class ExtractDataValues extends DoFn { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(new String(c.element().getDataAsBytes(), StandardCharsets.UTF_8)); - } - } - - private static final class RandomPartitioner implements KinesisPartitioner { - @Override - public String getPartitionKey(byte[] value) { - Random rand = new Random(); - int n = rand.nextInt(options.getNumberOfShards()) + 1; - return String.valueOf(n); - } - - @Override - public String getExplicitHashKey(byte[] value) { - return null; - } - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOReadTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOReadTest.java deleted file mode 100644 index fdacc62bdb4a..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOReadTest.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.junit.Assert.assertEquals; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.regions.Regions; -import org.apache.beam.sdk.io.kinesis.KinesisIO.Read; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for non trivial builder variants of {@link KinesisIO#read}. */ -@RunWith(JUnit4.class) -public class KinesisIOReadTest { - private static final String ACCESS_KEY_ID = "ACCESS_KEY_ID"; - private static final String SECRET_ACCESS_KEY = "SECRET_ACCESS_KEY"; - private static final boolean VERIFICATION_DISABLED = false; - - @Test - public void testReadWithBasicCredentials() { - Regions region = Regions.US_EAST_1; - Read read = - KinesisIO.read().withAWSClientsProvider(ACCESS_KEY_ID, SECRET_ACCESS_KEY, region); - - assertEquals( - read.getAWSClientsProvider(), - new BasicKinesisProvider( - new AWSStaticCredentialsProvider( - new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)), - region, - null, - true)); - } - - @Test - public void testReadWithCredentialsProvider() { - Regions region = Regions.US_EAST_1; - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - - Read read = KinesisIO.read().withAWSClientsProvider(credentialsProvider, region); - - assertEquals( - read.getAWSClientsProvider(), - new 
BasicKinesisProvider(credentialsProvider, region, null, true)); - } - - @Test - public void testReadWithBasicCredentialsAndCustomEndpoint() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - - Read read = - KinesisIO.read() - .withAWSClientsProvider(ACCESS_KEY_ID, SECRET_ACCESS_KEY, region, customEndpoint); - - assertEquals( - read.getAWSClientsProvider(), - new BasicKinesisProvider( - new AWSStaticCredentialsProvider( - new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)), - region, - customEndpoint, - true)); - } - - @Test - public void testReadWithCredentialsProviderAndCustomEndpoint() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - - Read read = - KinesisIO.read().withAWSClientsProvider(credentialsProvider, region, customEndpoint); - - assertEquals( - read.getAWSClientsProvider(), - new BasicKinesisProvider(credentialsProvider, region, customEndpoint, true)); - } - - @Test - public void testReadWithBasicCredentialsAndVerificationDisabled() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - - Read read = - KinesisIO.read() - .withAWSClientsProvider( - ACCESS_KEY_ID, SECRET_ACCESS_KEY, region, customEndpoint, VERIFICATION_DISABLED); - - assertEquals( - read.getAWSClientsProvider(), - new BasicKinesisProvider( - new AWSStaticCredentialsProvider( - new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)), - region, - customEndpoint, - VERIFICATION_DISABLED)); - } - - @Test - public void testReadWithCredentialsProviderAndVerificationDisabled() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - - Read read = - KinesisIO.read() - .withAWSClientsProvider( - credentialsProvider, region, customEndpoint, VERIFICATION_DISABLED); - - 
assertEquals( - read.getAWSClientsProvider(), - new BasicKinesisProvider( - credentialsProvider, region, customEndpoint, VERIFICATION_DISABLED)); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOWriteTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOWriteTest.java deleted file mode 100644 index 6884b199a1e3..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisIOWriteTest.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.junit.Assert.assertEquals; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.regions.Regions; -import org.apache.beam.sdk.io.kinesis.KinesisIO.Write; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for non trivial builder variants of {@link KinesisIO#write()}. 
*/ -@RunWith(JUnit4.class) -public class KinesisIOWriteTest { - private static final String ACCESS_KEY_ID = "ACCESS_KEY_ID"; - private static final String SECRET_KEY = "SECRET_KEY"; - private static final boolean VERIFICATION_DISABLED = false; - - @Test - public void testReadWithBasicCredentials() { - Regions region = Regions.US_EAST_1; - Write write = KinesisIO.write().withAWSClientsProvider(ACCESS_KEY_ID, SECRET_KEY, region); - - assertEquals( - write.getAWSClientsProvider(), - new BasicKinesisProvider( - new AWSStaticCredentialsProvider(new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_KEY)), - region, - null, - true)); - } - - @Test - public void testReadWithCredentialsProvider() { - Regions region = Regions.US_EAST_1; - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - - Write write = KinesisIO.write().withAWSClientsProvider(credentialsProvider, region); - - assertEquals( - write.getAWSClientsProvider(), - new BasicKinesisProvider(credentialsProvider, region, null, true)); - } - - @Test - public void testReadWithBasicCredentialsAndCustomEndpoint() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - BasicAWSCredentials credentials = new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_KEY); - - Write write = - KinesisIO.write().withAWSClientsProvider(ACCESS_KEY_ID, SECRET_KEY, region, customEndpoint); - - assertEquals( - write.getAWSClientsProvider(), - new BasicKinesisProvider( - new AWSStaticCredentialsProvider(credentials), region, customEndpoint, true)); - } - - @Test - public void testReadWithCredentialsProviderAndCustomEndpoint() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - - Write write = - KinesisIO.write().withAWSClientsProvider(credentialsProvider, region, customEndpoint); - - assertEquals( - write.getAWSClientsProvider(), - new 
BasicKinesisProvider(credentialsProvider, region, customEndpoint, true)); - } - - @Test - public void testReadWithBasicCredentialsAndVerificationDisabled() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - BasicAWSCredentials credentials = new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_KEY); - - Write write = - KinesisIO.write() - .withAWSClientsProvider( - ACCESS_KEY_ID, SECRET_KEY, region, customEndpoint, VERIFICATION_DISABLED); - - assertEquals( - write.getAWSClientsProvider(), - new BasicKinesisProvider( - new AWSStaticCredentialsProvider(credentials), - region, - customEndpoint, - VERIFICATION_DISABLED)); - } - - @Test - public void testReadWithCredentialsProviderAndVerificationDisabled() { - String customEndpoint = "localhost:9999"; - Regions region = Regions.US_WEST_1; - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - - Write write = - KinesisIO.write() - .withAWSClientsProvider( - credentialsProvider, region, customEndpoint, VERIFICATION_DISABLED); - - assertEquals( - write.getAWSClientsProvider(), - new BasicKinesisProvider( - credentialsProvider, region, customEndpoint, VERIFICATION_DISABLED)); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java deleted file mode 100644 index 77cabe858f52..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.newArrayList; - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import java.util.List; -import org.apache.beam.sdk.Pipeline.PipelineExecutionException; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.joda.time.DateTime; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link AmazonKinesisMock}. 
*/ -@RunWith(JUnit4.class) -public class KinesisMockReadTest { - - @Rule public final transient TestPipeline p = TestPipeline.create(); - - private final int noOfShards = 3; - private final int noOfEventsPerShard = 100; - - @Test - public void readsDataFromMockKinesis() { - List> testData = defaultTestData(); - verifyReadWithProvider(new AmazonKinesisMock.Provider(testData, 10), testData); - } - - @Test(expected = PipelineExecutionException.class) - public void readsDataFromMockKinesisWithLimitFailure() { - List> testData = defaultTestData(); - verifyReadWithProvider( - new AmazonKinesisMock.Provider(testData, 10).withExpectedListShardsLimitExceededException(), - testData); - } - - public void verifyReadWithProvider( - AmazonKinesisMock.Provider provider, List> testData) { - PCollection result = - p.apply( - KinesisIO.read() - .withStreamName("stream") - .withInitialPositionInStream(InitialPositionInStream.TRIM_HORIZON) - .withAWSClientsProvider(provider) - .withArrivalTimeWatermarkPolicy() - .withMaxNumRecords(noOfShards * noOfEventsPerShard)) - .apply(ParDo.of(new KinesisRecordToTestData())); - PAssert.that(result).containsInAnyOrder(Iterables.concat(testData)); - p.run(); - } - - static class KinesisRecordToTestData extends DoFn { - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - c.output(new AmazonKinesisMock.TestData(c.element())); - } - } - - private List> defaultTestData() { - return provideTestData(noOfShards, noOfEventsPerShard); - } - - private List> provideTestData( - int noOfShards, int noOfEventsPerShard) { - - int seqNumber = 0; - - List> shardedData = newArrayList(); - for (int i = 0; i < noOfShards; ++i) { - List shardData = newArrayList(); - shardedData.add(shardData); - - DateTime arrival = DateTime.now(); - for (int j = 0; j < noOfEventsPerShard; ++j) { - arrival = arrival.plusSeconds(1); - - seqNumber++; - shardData.add( - new AmazonKinesisMock.TestData( - Integer.toString(seqNumber), arrival.toInstant(), 
Integer.toString(seqNumber))); - } - } - - return shardedData; - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockWriteTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockWriteTest.java deleted file mode 100644 index 33b0c3a096ab..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockWriteTest.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.mock; - -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.producer.IKinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Properties; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link KinesisIO.Write}. 
*/ -@RunWith(JUnit4.class) -public class KinesisMockWriteTest { - private static final String STREAM = "BEAM"; - private static final String PARTITION_KEY = "partitionKey"; - - @Rule public final transient TestPipeline p = TestPipeline.create(); - @Rule public final transient TestPipeline p2 = TestPipeline.create(); - @Rule public ExpectedException thrown = ExpectedException.none(); - - @Before - public void beforeTest() { - KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); - kinesisService.init(STREAM, 1); - } - - @Test - public void testWriteBuildsCorrectly() { - Properties properties = new Properties(); - properties.setProperty("KinesisEndpoint", "localhost"); - properties.setProperty("KinesisPort", "4567"); - - KinesisIO.Write write = - KinesisIO.write() - .withStreamName(STREAM) - .withPartitionKey(PARTITION_KEY) - .withPartitioner(new BasicKinesisPartitioner()) - .withAWSClientsProvider(new FakeKinesisProvider()) - .withProducerProperties(properties) - .withRetries(10); - - assertEquals(STREAM, write.getStreamName()); - assertEquals(PARTITION_KEY, write.getPartitionKey()); - assertEquals(properties, write.getProducerProperties()); - assertEquals(FakeKinesisProvider.class, write.getAWSClientsProvider().getClass()); - assertEquals(BasicKinesisPartitioner.class, write.getPartitioner().getClass()); - assertEquals(10, write.getRetries()); - - assertEquals("localhost", write.getProducerProperties().getProperty("KinesisEndpoint")); - assertEquals("4567", write.getProducerProperties().getProperty("KinesisPort")); - } - - @Test - public void testWriteValidationFailsMissingStreamName() { - KinesisIO.Write write = - KinesisIO.write() - .withPartitionKey(PARTITION_KEY) - .withAWSClientsProvider(new FakeKinesisProvider()); - - thrown.expect(IllegalArgumentException.class); - write.expand(null); - } - - @Test - public void testWriteValidationFailsMissingPartitioner() { - KinesisIO.Write write = - 
KinesisIO.write().withStreamName(STREAM).withAWSClientsProvider(new FakeKinesisProvider()); - - thrown.expect(IllegalArgumentException.class); - write.expand(null); - } - - @Test - public void testWriteValidationFailsPartitionerAndPartitioneKey() { - KinesisIO.Write write = - KinesisIO.write() - .withStreamName(STREAM) - .withPartitionKey(PARTITION_KEY) - .withPartitioner(new BasicKinesisPartitioner()) - .withAWSClientsProvider(new FakeKinesisProvider()); - - thrown.expect(IllegalArgumentException.class); - write.expand(null); - } - - @Test - public void testWriteValidationFailsMissingAWSClientsProvider() { - KinesisIO.Write write = - KinesisIO.write().withPartitionKey(PARTITION_KEY).withStreamName(STREAM); - - thrown.expect(IllegalArgumentException.class); - write.expand(null); - } - - @Test - public void testSetInvalidProperty() { - Properties properties = new Properties(); - properties.setProperty("KinesisPort", "qwe"); - - KinesisIO.Write write = - KinesisIO.write() - .withStreamName(STREAM) - .withPartitionKey(PARTITION_KEY) - .withAWSClientsProvider(new FakeKinesisProvider()) - .withProducerProperties(properties); - - thrown.expect(IllegalArgumentException.class); - write.expand(null); - } - - @Test - public void testWrite() { - KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); - - Properties properties = new Properties(); - properties.setProperty("KinesisEndpoint", "localhost"); - properties.setProperty("KinesisPort", "4567"); - properties.setProperty("VerifyCertificate", "false"); - - Iterable data = - ImmutableList.of( - "1".getBytes(StandardCharsets.UTF_8), - "2".getBytes(StandardCharsets.UTF_8), - "3".getBytes(StandardCharsets.UTF_8)); - p.apply(Create.of(data)) - .apply( - KinesisIO.write() - .withStreamName(STREAM) - .withPartitionKey(PARTITION_KEY) - .withAWSClientsProvider(new FakeKinesisProvider()) - .withProducerProperties(properties)); - p.run().waitUntilFinish(); - - assertEquals(3, kinesisService.getAddedRecords().get()); - } 
- - @Test - public void testWriteFailed() { - Iterable data = ImmutableList.of("1".getBytes(StandardCharsets.UTF_8)); - p.apply(Create.of(data)) - .apply( - KinesisIO.write() - .withStreamName(STREAM) - .withPartitionKey(PARTITION_KEY) - .withAWSClientsProvider(new FakeKinesisProvider().setFailedFlush(true)) - .withRetries(2)); - - thrown.expect(RuntimeException.class); - p.run().waitUntilFinish(); - } - - @Test - public void testWriteAndReadFromMockKinesis() { - KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); - - Iterable data = - ImmutableList.of( - "1".getBytes(StandardCharsets.UTF_8), "2".getBytes(StandardCharsets.UTF_8)); - p.apply(Create.of(data)) - .apply( - KinesisIO.write() - .withStreamName(STREAM) - .withPartitionKey(PARTITION_KEY) - .withAWSClientsProvider(new FakeKinesisProvider())); - p.run().waitUntilFinish(); - assertEquals(2, kinesisService.getAddedRecords().get()); - - List> testData = kinesisService.getShardedData(); - - int noOfShards = 1; - int noOfEventsPerShard = 2; - PCollection result = - p2.apply( - KinesisIO.read() - .withStreamName(STREAM) - .withInitialPositionInStream(InitialPositionInStream.TRIM_HORIZON) - .withAWSClientsProvider(new AmazonKinesisMock.Provider(testData, 10)) - .withMaxNumRecords(noOfShards * noOfEventsPerShard)) - .apply(ParDo.of(new KinesisMockReadTest.KinesisRecordToTestData())); - PAssert.that(result).containsInAnyOrder(Iterables.concat(testData)); - p2.run().waitUntilFinish(); - } - - private static final class BasicKinesisPartitioner implements KinesisPartitioner { - @Override - public String getPartitionKey(byte[] value) { - return String.valueOf(value.length); - } - - @Override - public String getExplicitHashKey(byte[] value) { - return null; - } - } - - private static final class FakeKinesisProvider implements AWSClientsProvider { - private boolean isFailedFlush = false; - - public FakeKinesisProvider() {} - - public FakeKinesisProvider setFailedFlush(boolean failedFlush) { - 
isFailedFlush = failedFlush; - return this; - } - - @Override - public AmazonKinesis getKinesisClient() { - return mock(AmazonKinesis.class); - } - - @Override - public AmazonCloudWatch getCloudWatchClient() { - throw new RuntimeException("Not implemented"); - } - - @Override - public IKinesisProducer createKinesisProducer(KinesisProducerConfiguration config) { - return new KinesisProducerMock(config, isFailedFlush); - } - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisProducerMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisProducerMock.java deleted file mode 100644 index 17c8c1ddb815..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisProducerMock.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import com.amazonaws.services.kinesis.producer.IKinesisProducer; -import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; -import com.amazonaws.services.kinesis.producer.Metric; -import com.amazonaws.services.kinesis.producer.UserRecord; -import com.amazonaws.services.kinesis.producer.UserRecordResult; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.atomic.AtomicInteger; -import org.joda.time.DateTime; - -/** Simple mock implementation of {@link IKinesisProducer} for testing. */ -public class KinesisProducerMock implements IKinesisProducer { - - private boolean isFailedFlush = false; - - private List addedRecords = Collections.synchronizedList(new ArrayList<>()); - - private KinesisServiceMock kinesisService = KinesisServiceMock.getInstance(); - - private AtomicInteger seqNumber = new AtomicInteger(0); - - public KinesisProducerMock() {} - - public KinesisProducerMock(KinesisProducerConfiguration config, boolean isFailedFlush) { - this.isFailedFlush = isFailedFlush; - this.seqNumber.set(0); - } - - @Override - public ListenableFuture addUserRecord( - String stream, String partitionKey, ByteBuffer data) { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public ListenableFuture addUserRecord(UserRecord userRecord) { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public synchronized ListenableFuture addUserRecord( - String stream, String partitionKey, String explicitHashKey, ByteBuffer data) { - seqNumber.incrementAndGet(); - SettableFuture f = SettableFuture.create(); - f.set( - new UserRecordResult( - new ArrayList<>(), String.valueOf(seqNumber.get()), explicitHashKey, 
!isFailedFlush)); - - if (kinesisService.getExistedStream().equals(stream)) { - addedRecords.add(new UserRecord(stream, partitionKey, explicitHashKey, data)); - } - return f; - } - - @Override - public int getOutstandingRecordsCount() { - return addedRecords.size(); - } - - @Override - public List getMetrics(String metricName, int windowSeconds) - throws InterruptedException, ExecutionException { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public List getMetrics(String metricName) - throws InterruptedException, ExecutionException { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public List getMetrics() throws InterruptedException, ExecutionException { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public List getMetrics(int windowSeconds) - throws InterruptedException, ExecutionException { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public void destroy() {} - - @Override - public void flush(String stream) { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public synchronized void flush() { - DateTime arrival = DateTime.now(); - for (int i = 0; i < addedRecords.size(); i++) { - UserRecord record = addedRecords.get(i); - arrival = arrival.plusSeconds(1); - kinesisService.addShardedData(record.getData(), arrival); - addedRecords.remove(i); - } - } - - @Override - public synchronized void flushSync() { - flush(); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java deleted file mode 100644 index 61212fb05570..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or 
more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; - -import java.util.Iterator; -import java.util.List; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -/** * */ -@RunWith(MockitoJUnitRunner.class) -public class KinesisReaderCheckpointTest { - - @Mock private ShardCheckpoint a, b, c; - - private KinesisReaderCheckpoint checkpoint; - - @Before - public void setUp() { - checkpoint = new KinesisReaderCheckpoint(asList(a, b, c)); - } - - @Test - public void splitsCheckpointAccordingly() { - verifySplitInto(1); - verifySplitInto(2); - verifySplitInto(3); - verifySplitInto(4); - } - - @Test(expected = UnsupportedOperationException.class) - public void isImmutable() { - Iterator iterator = checkpoint.iterator(); - iterator.remove(); - } - - private void verifySplitInto(int size) { - List split = checkpoint.splitInto(size); - assertThat(Iterables.concat(split)).containsOnly(a, b, c); - assertThat(split).hasSize(Math.min(size, 3)); - } -} diff --git 
a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java deleted file mode 100644 index 64f0fe7c6538..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.util.NoSuchElementException; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -/** Tests {@link KinesisReader}. 
*/ -@RunWith(MockitoJUnitRunner.Silent.class) -public class KinesisReaderTest { - - @Mock private SimplifiedKinesisClient kinesis; - @Mock private CheckpointGenerator generator; - @Mock private ShardCheckpoint firstCheckpoint, secondCheckpoint; - @Mock private KinesisRecord a, b, c, d; - @Mock private KinesisSource kinesisSource; - @Mock private ShardReadersPool shardReadersPool; - - private KinesisReader reader; - - @Before - public void setUp() throws TransientKinesisException { - when(generator.generate(kinesis)) - .thenReturn(new KinesisReaderCheckpoint(asList(firstCheckpoint, secondCheckpoint))); - when(shardReadersPool.nextRecord()).thenReturn(CustomOptional.absent()); - when(a.getApproximateArrivalTimestamp()).thenReturn(Instant.now()); - when(b.getApproximateArrivalTimestamp()).thenReturn(Instant.now()); - when(c.getApproximateArrivalTimestamp()).thenReturn(Instant.now()); - when(d.getApproximateArrivalTimestamp()).thenReturn(Instant.now()); - - reader = spy(createReader(Duration.ZERO)); - } - - private KinesisReader createReader(Duration backlogBytesCheckThreshold) { - return new KinesisReader( - kinesis, - generator, - kinesisSource, - WatermarkPolicyFactory.withArrivalTimePolicy(), - RateLimitPolicyFactory.withoutLimiter(), - Duration.ZERO, - backlogBytesCheckThreshold, - ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD) { - @Override - ShardReadersPool createShardReadersPool() { - return shardReadersPool; - } - }; - } - - @Test - public void startReturnsFalseIfNoDataAtTheBeginning() throws IOException { - assertThat(reader.start()).isFalse(); - } - - @Test(expected = NoSuchElementException.class) - public void throwsNoSuchElementExceptionIfNoData() throws IOException { - reader.start(); - reader.getCurrent(); - } - - @Test - public void startReturnsTrueIfSomeDataAvailable() throws IOException { - when(shardReadersPool.nextRecord()) - .thenReturn(CustomOptional.of(a)) - .thenReturn(CustomOptional.absent()); - - assertThat(reader.start()).isTrue(); - } - - 
@Test - public void readsThroughAllDataAvailable() throws IOException { - when(shardReadersPool.nextRecord()) - .thenReturn(CustomOptional.of(c)) - .thenReturn(CustomOptional.absent()) - .thenReturn(CustomOptional.of(a)) - .thenReturn(CustomOptional.absent()) - .thenReturn(CustomOptional.of(d)) - .thenReturn(CustomOptional.of(b)) - .thenReturn(CustomOptional.absent()); - - assertThat(reader.start()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(c); - assertThat(reader.advance()).isFalse(); - assertThat(reader.advance()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(a); - assertThat(reader.advance()).isFalse(); - assertThat(reader.advance()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(d); - assertThat(reader.advance()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(b); - assertThat(reader.advance()).isFalse(); - } - - @Test - public void returnsCurrentWatermark() throws IOException { - Instant expectedWatermark = new Instant(123456L); - when(shardReadersPool.getWatermark()).thenReturn(expectedWatermark); - - reader.start(); - Instant currentWatermark = reader.getWatermark(); - - assertThat(currentWatermark).isEqualTo(expectedWatermark); - } - - @Test - public void getSplitBacklogBytesShouldReturnLastSeenValueWhenKinesisExceptionsOccur() - throws TransientKinesisException, IOException { - reader.start(); - when(kinesisSource.getStreamName()).thenReturn("stream1"); - when(shardReadersPool.getLatestRecordTimestamp()) - .thenReturn(Instant.now().minus(Duration.standardMinutes(1))); - when(kinesis.getBacklogBytes(eq("stream1"), any(Instant.class))) - .thenReturn(10L) - .thenThrow(TransientKinesisException.class) - .thenReturn(20L); - - assertThat(reader.getSplitBacklogBytes()).isEqualTo(10); - assertThat(reader.getSplitBacklogBytes()).isEqualTo(10); - assertThat(reader.getSplitBacklogBytes()).isEqualTo(20); - } - - @Test - public void getSplitBacklogBytesShouldReturnLastSeenValueWhenCalledFrequently() - throws TransientKinesisException, 
IOException { - KinesisReader backlogCachingReader = spy(createReader(Duration.standardSeconds(30))); - backlogCachingReader.start(); - when(shardReadersPool.getLatestRecordTimestamp()) - .thenReturn(Instant.now().minus(Duration.standardMinutes(1))); - when(kinesisSource.getStreamName()).thenReturn("stream1"); - when(kinesis.getBacklogBytes(eq("stream1"), any(Instant.class))) - .thenReturn(10L) - .thenReturn(20L); - - assertThat(backlogCachingReader.getSplitBacklogBytes()).isEqualTo(10); - assertThat(backlogCachingReader.getSplitBacklogBytes()).isEqualTo(10); - } - - @Test - public void getSplitBacklogBytesShouldReturnBacklogUnknown() - throws IOException, TransientKinesisException { - reader.start(); - when(kinesisSource.getStreamName()).thenReturn("stream1"); - when(shardReadersPool.getLatestRecordTimestamp()) - .thenReturn(BoundedWindow.TIMESTAMP_MIN_VALUE) - .thenReturn(Instant.now().minus(Duration.standardMinutes(1))); - when(kinesis.getBacklogBytes(eq("stream1"), any(Instant.class))).thenReturn(10L); - - assertThat(reader.getSplitBacklogBytes()) - .isEqualTo(UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN); - assertThat(reader.getSplitBacklogBytes()).isEqualTo(10); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java deleted file mode 100644 index 7df3643050ba..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import org.apache.beam.sdk.testing.CoderProperties; -import org.joda.time.Instant; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link KinesisRecordCoder}. */ -@RunWith(JUnit4.class) -public class KinesisRecordCoderTest { - - @Test - public void encodingAndDecodingWorks() throws Exception { - KinesisRecord record = - new KinesisRecord( - ByteBuffer.wrap("data".getBytes(StandardCharsets.UTF_8)), - "sequence", - 128L, - "partition", - Instant.now(), - Instant.now(), - "stream", - "shard"); - CoderProperties.coderDecodeEncodeEqual(new KinesisRecordCoder(), record); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisServiceMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisServiceMock.java deleted file mode 100644 index dcbe4224b630..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisServiceMock.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists.newArrayList; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import org.joda.time.DateTime; - -/** Simple mock implementation of Kinesis service for testing, singletone. */ -public class KinesisServiceMock { - private static KinesisServiceMock instance; - - // Mock stream where client is supposed to write - private String existedStream; - - private AtomicInteger addedRecords = new AtomicInteger(0); - private AtomicInteger seqNumber = new AtomicInteger(0); - private List> shardedData; - - private KinesisServiceMock() {} - - public static synchronized KinesisServiceMock getInstance() { - if (instance == null) { - instance = new KinesisServiceMock(); - } - return instance; - } - - public synchronized void init(String stream, int shardsNum) { - existedStream = stream; - addedRecords.set(0); - seqNumber.set(0); - shardedData = newArrayList(); - for (int i = 0; i < shardsNum; i++) { - List shardData = newArrayList(); - shardedData.add(shardData); - } - } - - public AtomicInteger getAddedRecords() { - return addedRecords; - } - - public String getExistedStream() { - return existedStream; - } - - public synchronized void addShardedData(ByteBuffer data, DateTime arrival) { - String dataString = StandardCharsets.UTF_8.decode(data).toString(); - - List shardData = shardedData.get(0); - - seqNumber.incrementAndGet(); - 
AmazonKinesisMock.TestData testData = - new AmazonKinesisMock.TestData( - dataString, arrival.toInstant(), Integer.toString(seqNumber.get())); - shardData.add(testData); - - addedRecords.incrementAndGet(); - } - - public synchronized List> getShardedData() { - return shardedData; - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java deleted file mode 100644 index 2ba932b35a3d..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.testing.TestPipelineOptions; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** Options for Kinesis integration tests. 
*/ -public interface KinesisTestOptions extends TestPipelineOptions { - - @Description("AWS region where Kinesis stream resided") - @Default.String("aws-kinesis-region") - String getAwsKinesisRegion(); - - void setAwsKinesisRegion(String value); - - @Description("Kinesis stream name") - @Default.String("aws-kinesis-stream") - String getAwsKinesisStream(); - - void setAwsKinesisStream(String value); - - @Description("AWS secret key") - @Default.String("aws-secret-key") - String getAwsSecretKey(); - - void setAwsSecretKey(String value); - - @Description("AWS access key") - @Default.String("aws-access-key") - String getAwsAccessKey(); - - void setAwsAccessKey(String value); - - @Description("Aws service endpoint") - @Nullable - String getAwsServiceEndpoint(); - - void setAwsServiceEndpoint(String awsServiceEndpoint); - - @Description("Flag for certificate verification") - @Default.Boolean(true) - Boolean getAwsVerifyCertificate(); - - void setAwsVerifyCertificate(Boolean awsVerifyCertificate); - - @Description("Number of shards of stream") - @Default.Integer(2) - Integer getNumberOfShards(); - - void setNumberOfShards(Integer count); - - @Description("Number of records that will be written and read by the test") - @Default.Integer(1000) - Integer getNumberOfRecords(); - - void setNumberOfRecords(Integer count); - - @Description("Use localstack. Disable to test with real Kinesis") - @Default.Boolean(true) - Boolean getUseLocalstack(); - - void setUseLocalstack(Boolean useLocalstack); - - @Description("Create stream. 
Enabled when using localstack") - @Default.Boolean(false) - Boolean getCreateStream(); - - void setCreateStream(Boolean createStream); -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java deleted file mode 100644 index 0d144d19a909..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static org.apache.beam.sdk.io.kinesis.RateLimitPolicyFactory.withDefaultRateLimiter; -import static org.assertj.core.api.Assertions.assertThat; -import static org.joda.time.Duration.millis; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.clearInvocations; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.powermock.api.mockito.PowerMockito.verifyStatic; - -import java.util.concurrent.atomic.AtomicLong; -import org.apache.beam.sdk.io.kinesis.RateLimitPolicyFactory.DefaultRateLimiter; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.Sleeper; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -@RunWith(PowerMockRunner.class) -@PrepareForTest(RateLimitPolicyFactory.class) -public class RateLimitPolicyFactoryTest { - - @Test - public void defaultRateLimiterShouldUseBackoffs() throws Exception { - assertThat(withDefaultRateLimiter().getRateLimitPolicy()) - .isInstanceOf(DefaultRateLimiter.class); - assertThat(withDefaultRateLimiter(millis(1), millis(1), millis(1)).getRateLimitPolicy()) - .isInstanceOf(DefaultRateLimiter.class); - - Sleeper sleeper = mock(Sleeper.class); - BackOff emptySuccess = mock(BackOff.class); - BackOff throttled = mock(BackOff.class); - - RateLimitPolicy policy = new DefaultRateLimiter(emptySuccess, throttled, sleeper); - - // 
reset emptySuccess after receiving at least 1 record, throttled is reset on any success - policy.onSuccess(ImmutableList.of(mock(KinesisRecord.class))); - - verify(emptySuccess).reset(); - verify(throttled).reset(); - verifyNoInteractions(sleeper); - clearInvocations(emptySuccess, throttled); - - when(emptySuccess.nextBackOffMillis()).thenReturn(88L, 99L); - // throttle if no records received, throttled is reset again - policy.onSuccess(ImmutableList.of()); - policy.onSuccess(ImmutableList.of()); - - verify(emptySuccess, times(2)).nextBackOffMillis(); - verify(throttled, times(2)).reset(); - verify(sleeper).sleep(88L); - verify(sleeper).sleep(99L); - verifyNoMoreInteractions(sleeper, throttled, emptySuccess); - clearInvocations(emptySuccess, throttled, sleeper); - - when(throttled.nextBackOffMillis()).thenReturn(111L, 222L); - // throttle onThrottle - policy.onThrottle(mock(KinesisClientThrottledException.class)); - policy.onThrottle(mock(KinesisClientThrottledException.class)); - - verify(throttled, times(2)).nextBackOffMillis(); - verify(sleeper).sleep(111L); - verify(sleeper).sleep(222L); - verifyNoMoreInteractions(sleeper, throttled, emptySuccess); - } - - @Test - public void withoutLimiterShouldDoNothing() throws Exception { - PowerMockito.spy(Thread.class); - PowerMockito.doNothing().when(Thread.class); - Thread.sleep(anyLong()); - RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withoutLimiter().getRateLimitPolicy(); - rateLimitPolicy.onSuccess(ImmutableList.of()); - verifyStatic(Thread.class, never()); - Thread.sleep(anyLong()); - } - - @Test - public void shouldDelayDefaultInterval() throws Exception { - PowerMockito.spy(Thread.class); - PowerMockito.doNothing().when(Thread.class); - Thread.sleep(anyLong()); - RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withFixedDelay().getRateLimitPolicy(); - rateLimitPolicy.onSuccess(ImmutableList.of()); - verifyStatic(Thread.class); - Thread.sleep(eq(1000L)); - } - - @Test - public void 
shouldDelayFixedInterval() throws Exception { - PowerMockito.spy(Thread.class); - PowerMockito.doNothing().when(Thread.class); - Thread.sleep(anyLong()); - RateLimitPolicy rateLimitPolicy = - RateLimitPolicyFactory.withFixedDelay(millis(500)).getRateLimitPolicy(); - rateLimitPolicy.onSuccess(ImmutableList.of()); - verifyStatic(Thread.class); - Thread.sleep(eq(500L)); - } - - @Test - public void shouldDelayDynamicInterval() throws Exception { - PowerMockito.spy(Thread.class); - PowerMockito.doNothing().when(Thread.class); - Thread.sleep(anyLong()); - AtomicLong delay = new AtomicLong(0L); - RateLimitPolicy rateLimitPolicy = - RateLimitPolicyFactory.withDelay(() -> millis(delay.getAndUpdate(d -> d ^ 1))) - .getRateLimitPolicy(); - rateLimitPolicy.onSuccess(ImmutableList.of()); - verifyStatic(Thread.class); - Thread.sleep(eq(0L)); - Thread.sleep(eq(1L)); - Thread.sleep(eq(0L)); - Thread.sleep(eq(1L)); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java deleted file mode 100644 index ad1e58c265e7..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.mockito.Mockito.when; - -import java.util.Collections; -import java.util.List; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; -import org.assertj.core.api.Assertions; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -/** * */ -@RunWith(MockitoJUnitRunner.class) -public class RecordFilterTest { - - @Mock private ShardCheckpoint checkpoint; - @Mock private KinesisRecord record1, record2, record3, record4, record5; - - @Test - public void shouldFilterOutRecordsBeforeOrAtCheckpoint() { - when(checkpoint.isBeforeOrAt(record1)).thenReturn(false); - when(checkpoint.isBeforeOrAt(record2)).thenReturn(true); - when(checkpoint.isBeforeOrAt(record3)).thenReturn(true); - when(checkpoint.isBeforeOrAt(record4)).thenReturn(false); - when(checkpoint.isBeforeOrAt(record5)).thenReturn(true); - List records = Lists.newArrayList(record1, record2, record3, record4, record5); - RecordFilter underTest = new RecordFilter(); - - List retainedRecords = underTest.apply(records, checkpoint); - - Assertions.assertThat(retainedRecords).containsOnly(record2, record3, record5); - } - - @Test - public void shouldNotFailOnEmptyList() { - List records = Collections.emptyList(); - RecordFilter underTest = new RecordFilter(); - - List retainedRecords = underTest.apply(records, checkpoint); - - Assertions.assertThat(retainedRecords).isEmpty(); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java deleted file mode 100644 index 227542cb8055..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java +++ /dev/null @@ 
-1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream.LATEST; -import static com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream.TRIM_HORIZON; -import static com.amazonaws.services.kinesis.model.ShardIteratorType.AFTER_SEQUENCE_NUMBER; -import static com.amazonaws.services.kinesis.model.ShardIteratorType.AT_SEQUENCE_NUMBER; -import static com.amazonaws.services.kinesis.model.ShardIteratorType.AT_TIMESTAMP; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Matchers.eq; -import static org.mockito.Matchers.isNull; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; -import com.amazonaws.services.kinesis.model.ShardIteratorType; -import java.io.IOException; -import org.joda.time.DateTime; -import org.joda.time.Instant; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import 
org.mockito.junit.MockitoJUnitRunner; - -/** */ -@RunWith(MockitoJUnitRunner.class) -public class ShardCheckpointTest { - - private static final String AT_SEQUENCE_SHARD_IT = "AT_SEQUENCE_SHARD_IT"; - private static final String AFTER_SEQUENCE_SHARD_IT = "AFTER_SEQUENCE_SHARD_IT"; - private static final String STREAM_NAME = "STREAM"; - private static final String SHARD_ID = "SHARD_ID"; - @Mock private SimplifiedKinesisClient client; - - @Before - public void setUp() throws IOException, TransientKinesisException { - when(client.getShardIterator( - eq(STREAM_NAME), - eq(SHARD_ID), - eq(AT_SEQUENCE_NUMBER), - anyString(), - isNull(Instant.class))) - .thenReturn(AT_SEQUENCE_SHARD_IT); - when(client.getShardIterator( - eq(STREAM_NAME), - eq(SHARD_ID), - eq(AFTER_SEQUENCE_NUMBER), - anyString(), - isNull(Instant.class))) - .thenReturn(AFTER_SEQUENCE_SHARD_IT); - } - - @Test - public void testProvidingShardIterator() throws IOException, TransientKinesisException { - assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", null).getShardIterator(client)) - .isEqualTo(AT_SEQUENCE_SHARD_IT); - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", null).getShardIterator(client)) - .isEqualTo(AFTER_SEQUENCE_SHARD_IT); - assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 10L).getShardIterator(client)) - .isEqualTo(AT_SEQUENCE_SHARD_IT); - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 10L).getShardIterator(client)) - .isEqualTo(AT_SEQUENCE_SHARD_IT); - } - - @Test - public void testComparisonWithExtendedSequenceNumber() { - assertThat( - new ShardCheckpoint("", "", new StartingPoint(LATEST)) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("100", 0L)))) - .isTrue(); - - assertThat( - new ShardCheckpoint("", "", new StartingPoint(TRIM_HORIZON)) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("100", 0L)))) - .isTrue(); - - assertThat( - checkpoint(AFTER_SEQUENCE_NUMBER, "10", 1L) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("100", 0L)))) - .isTrue(); - - assertThat( 
- checkpoint(AT_SEQUENCE_NUMBER, "100", 0L) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("100", 0L)))) - .isTrue(); - - assertThat( - checkpoint(AFTER_SEQUENCE_NUMBER, "100", 0L) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("100", 0L)))) - .isFalse(); - - assertThat( - checkpoint(AT_SEQUENCE_NUMBER, "100", 1L) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("100", 0L)))) - .isFalse(); - - assertThat( - checkpoint(AFTER_SEQUENCE_NUMBER, "100", 0L) - .isBeforeOrAt(recordWith(new ExtendedSequenceNumber("99", 1L)))) - .isFalse(); - } - - @Test - public void testComparisonWithTimestamp() { - DateTime referenceTimestamp = DateTime.now(); - - assertThat( - checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) - .isBeforeOrAt(recordWith(referenceTimestamp.minusMillis(10).toInstant()))) - .isFalse(); - - assertThat( - checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) - .isBeforeOrAt(recordWith(referenceTimestamp.toInstant()))) - .isTrue(); - - assertThat( - checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) - .isBeforeOrAt(recordWith(referenceTimestamp.plusMillis(10).toInstant()))) - .isTrue(); - } - - private KinesisRecord recordWith(ExtendedSequenceNumber extendedSequenceNumber) { - KinesisRecord record = mock(KinesisRecord.class); - when(record.getExtendedSequenceNumber()).thenReturn(extendedSequenceNumber); - return record; - } - - private ShardCheckpoint checkpoint( - ShardIteratorType iteratorType, String sequenceNumber, Long subSequenceNumber) { - return new ShardCheckpoint( - STREAM_NAME, SHARD_ID, iteratorType, sequenceNumber, subSequenceNumber); - } - - private KinesisRecord recordWith(Instant approximateArrivalTimestamp) { - KinesisRecord record = mock(KinesisRecord.class); - when(record.getApproximateArrivalTimestamp()).thenReturn(approximateArrivalTimestamp); - return record; - } - - private ShardCheckpoint checkpoint(ShardIteratorType iteratorType, Instant timestamp) { - return new ShardCheckpoint(STREAM_NAME, SHARD_ID, 
iteratorType, timestamp); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java deleted file mode 100644 index 74c9446d316a..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static java.util.Collections.singletonList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.TimeUnit; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; - -/** Tests {@link ShardReadersPool}. 
*/ -@RunWith(MockitoJUnitRunner.Silent.class) -public class ShardReadersPoolTest { - - private static final int TIMEOUT_IN_MILLIS = (int) TimeUnit.SECONDS.toMillis(10); - - @Mock private ShardRecordsIterator firstIterator, secondIterator, thirdIterator, fourthIterator; - @Mock private ShardCheckpoint firstCheckpoint, secondCheckpoint; - @Mock private SimplifiedKinesisClient kinesis; - @Mock private KinesisRecord a, b, c, d; - @Mock private WatermarkPolicyFactory watermarkPolicyFactory; - @Mock private RateLimitPolicyFactory rateLimitPolicyFactory; - @Mock private RateLimitPolicy customRateLimitPolicy; - - private ShardReadersPool shardReadersPool; - private final Instant now = Instant.now(); - - @Before - public void setUp() throws TransientKinesisException { - when(a.getShardId()).thenReturn("shard1"); - when(b.getShardId()).thenReturn("shard1"); - when(c.getShardId()).thenReturn("shard2"); - when(d.getShardId()).thenReturn("shard2"); - when(firstCheckpoint.getShardId()).thenReturn("shard1"); - when(firstCheckpoint.getStreamName()).thenReturn("testStream"); - when(secondCheckpoint.getShardId()).thenReturn("shard2"); - when(firstIterator.getShardId()).thenReturn("shard1"); - when(firstIterator.getStreamName()).thenReturn("testStream"); - when(firstIterator.getCheckpoint()).thenReturn(firstCheckpoint); - when(secondIterator.getShardId()).thenReturn("shard2"); - when(secondIterator.getCheckpoint()).thenReturn(secondCheckpoint); - when(thirdIterator.getShardId()).thenReturn("shard3"); - when(fourthIterator.getShardId()).thenReturn("shard4"); - - WatermarkPolicy watermarkPolicy = - WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy(); - RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withoutLimiter().getRateLimitPolicy(); - - KinesisReaderCheckpoint checkpoint = - new KinesisReaderCheckpoint(ImmutableList.of(firstCheckpoint, secondCheckpoint)); - shardReadersPool = - Mockito.spy( - new ShardReadersPool( - kinesis, checkpoint, 
watermarkPolicyFactory, rateLimitPolicyFactory, 100)); - - when(watermarkPolicyFactory.createWatermarkPolicy()).thenReturn(watermarkPolicy); - when(rateLimitPolicyFactory.getRateLimitPolicy()).thenReturn(rateLimitPolicy); - - doReturn(firstIterator).when(shardReadersPool).createShardIterator(kinesis, firstCheckpoint); - doReturn(secondIterator).when(shardReadersPool).createShardIterator(kinesis, secondCheckpoint); - } - - @After - public void clean() { - shardReadersPool.stop(); - } - - @Test - public void shouldReturnAllRecords() - throws TransientKinesisException, KinesisShardClosedException { - when(firstIterator.readNextBatch()) - .thenReturn(Collections.emptyList()) - .thenReturn(ImmutableList.of(a, b)) - .thenReturn(Collections.emptyList()); - when(secondIterator.readNextBatch()) - .thenReturn(singletonList(c)) - .thenReturn(singletonList(d)) - .thenReturn(Collections.emptyList()); - - shardReadersPool.start(); - List fetchedRecords = new ArrayList<>(); - while (fetchedRecords.size() < 4) { - CustomOptional nextRecord = shardReadersPool.nextRecord(); - if (nextRecord.isPresent()) { - fetchedRecords.add(nextRecord.get()); - } - } - assertThat(fetchedRecords).containsExactlyInAnyOrder(a, b, c, d); - assertThat(shardReadersPool.getRecordsQueue().remainingCapacity()).isEqualTo(100 * 2); - } - - @Test - public void shouldReturnAbsentOptionalWhenNoRecords() - throws TransientKinesisException, KinesisShardClosedException { - when(firstIterator.readNextBatch()).thenReturn(Collections.emptyList()); - when(secondIterator.readNextBatch()).thenReturn(Collections.emptyList()); - - shardReadersPool.start(); - CustomOptional nextRecord = shardReadersPool.nextRecord(); - assertThat(nextRecord.isPresent()).isFalse(); - } - - @Test - public void shouldCheckpointReadRecords() - throws TransientKinesisException, KinesisShardClosedException { - when(firstIterator.readNextBatch()) - .thenReturn(ImmutableList.of(a, b)) - .thenReturn(Collections.emptyList()); - 
when(secondIterator.readNextBatch()) - .thenReturn(singletonList(c)) - .thenReturn(singletonList(d)) - .thenReturn(Collections.emptyList()); - - shardReadersPool.start(); - int recordsFound = 0; - while (recordsFound < 4) { - CustomOptional nextRecord = shardReadersPool.nextRecord(); - if (nextRecord.isPresent()) { - recordsFound++; - KinesisRecord kinesisRecord = nextRecord.get(); - if ("shard1".equals(kinesisRecord.getShardId())) { - verify(firstIterator).ackRecord(kinesisRecord); - } else { - verify(secondIterator).ackRecord(kinesisRecord); - } - } - } - } - - @Test - public void shouldInterruptKinesisReadingAndStopShortly() - throws TransientKinesisException, KinesisShardClosedException { - when(firstIterator.readNextBatch()) - .thenAnswer( - (Answer>) - invocation -> { - Thread.sleep(TIMEOUT_IN_MILLIS / 2); - return Collections.emptyList(); - }); - shardReadersPool.start(); - - Stopwatch stopwatch = Stopwatch.createStarted(); - shardReadersPool.stop(); - assertThat(stopwatch.elapsed(TimeUnit.MILLISECONDS)).isLessThan(TIMEOUT_IN_MILLIS); - } - - @Test - public void shouldInterruptPuttingRecordsToQueueAndStopShortly() - throws TransientKinesisException, KinesisShardClosedException { - when(firstIterator.readNextBatch()).thenReturn(ImmutableList.of(a, b, c)); - KinesisReaderCheckpoint checkpoint = - new KinesisReaderCheckpoint(ImmutableList.of(firstCheckpoint, secondCheckpoint)); - - WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy(); - RateLimitPolicyFactory rateLimitPolicyFactory = RateLimitPolicyFactory.withoutLimiter(); - ShardReadersPool shardReadersPool = - new ShardReadersPool( - kinesis, checkpoint, watermarkPolicyFactory, rateLimitPolicyFactory, 2); - shardReadersPool.start(); - - Stopwatch stopwatch = Stopwatch.createStarted(); - shardReadersPool.stop(); - assertThat(stopwatch.elapsed(TimeUnit.MILLISECONDS)).isLessThan(TIMEOUT_IN_MILLIS); - } - - @Test - public void 
shouldStopReadingShardAfterReceivingShardClosedException() throws Exception { - when(firstIterator.readNextBatch()).thenThrow(KinesisShardClosedException.class); - when(firstIterator.findSuccessiveShardRecordIterators()).thenReturn(Collections.emptyList()); - - shardReadersPool.start(); - - verify(firstIterator, timeout(TIMEOUT_IN_MILLIS).times(1)).readNextBatch(); - verify(secondIterator, timeout(TIMEOUT_IN_MILLIS).atLeast(2)).readNextBatch(); - } - - @Test - public void shouldStartReadingSuccessiveShardsAfterReceivingShardClosedException() - throws Exception { - when(firstIterator.readNextBatch()).thenThrow(KinesisShardClosedException.class); - when(firstIterator.findSuccessiveShardRecordIterators()) - .thenReturn(ImmutableList.of(thirdIterator, fourthIterator)); - - shardReadersPool.start(); - - verify(thirdIterator, timeout(TIMEOUT_IN_MILLIS).atLeast(2)).readNextBatch(); - verify(fourthIterator, timeout(TIMEOUT_IN_MILLIS).atLeast(2)).readNextBatch(); - } - - @Test - public void shouldStopReadersPoolWhenLastShardReaderStopped() throws Exception { - when(firstIterator.readNextBatch()).thenThrow(KinesisShardClosedException.class); - when(firstIterator.findSuccessiveShardRecordIterators()).thenReturn(Collections.emptyList()); - - shardReadersPool.start(); - - verify(firstIterator, timeout(TIMEOUT_IN_MILLIS).times(1)).readNextBatch(); - } - - @Test - public void shouldStopReadersPoolAlsoWhenExceptionsOccurDuringStopping() throws Exception { - when(firstIterator.readNextBatch()).thenThrow(KinesisShardClosedException.class); - when(firstIterator.findSuccessiveShardRecordIterators()) - .thenThrow(TransientKinesisException.class) - .thenReturn(Collections.emptyList()); - - shardReadersPool.start(); - - verify(firstIterator, timeout(TIMEOUT_IN_MILLIS).times(2)).readNextBatch(); - } - - @Test - public void shouldReturnAbsentOptionalWhenStartedWithNoIterators() throws Exception { - KinesisReaderCheckpoint checkpoint = new KinesisReaderCheckpoint(Collections.emptyList()); - 
WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy(); - RateLimitPolicyFactory rateLimitPolicyFactory = RateLimitPolicyFactory.withoutLimiter(); - shardReadersPool = - Mockito.spy( - new ShardReadersPool( - kinesis, - checkpoint, - watermarkPolicyFactory, - rateLimitPolicyFactory, - ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD)); - doReturn(firstIterator) - .when(shardReadersPool) - .createShardIterator(eq(kinesis), any(ShardCheckpoint.class)); - - shardReadersPool.start(); - - assertThat(shardReadersPool.nextRecord()).isEqualTo(CustomOptional.absent()); - } - - @Test - public void shouldForgetClosedShardIterator() throws Exception { - when(firstIterator.readNextBatch()).thenThrow(KinesisShardClosedException.class); - List emptyList = Collections.emptyList(); - when(firstIterator.findSuccessiveShardRecordIterators()).thenReturn(emptyList); - - shardReadersPool.start(); - verify(shardReadersPool) - .startReadingShards(ImmutableList.of(firstIterator, secondIterator), "testStream"); - verify(shardReadersPool, timeout(TIMEOUT_IN_MILLIS)) - .startReadingShards(emptyList, "testStream"); - - KinesisReaderCheckpoint checkpointMark = shardReadersPool.getCheckpointMark(); - assertThat(checkpointMark.iterator()) - .extracting("shardId", String.class) - .containsOnly("shard2") - .doesNotContain("shard1"); - } - - @Test - public void shouldReturnTheLeastWatermarkOfAllShards() throws TransientKinesisException { - Instant threeMin = now.minus(Duration.standardMinutes(3)); - Instant twoMin = now.minus(Duration.standardMinutes(2)); - - when(firstIterator.getShardWatermark()).thenReturn(threeMin).thenReturn(now); - when(secondIterator.getShardWatermark()).thenReturn(twoMin); - - shardReadersPool.start(); - - assertThat(shardReadersPool.getWatermark()).isEqualTo(threeMin); - assertThat(shardReadersPool.getWatermark()).isEqualTo(twoMin); - - verify(firstIterator, times(2)).getShardWatermark(); - verify(secondIterator, 
times(2)).getShardWatermark(); - } - - @Test - public void shouldReturnTheOldestFromLatestRecordTimestampOfAllShards() - throws TransientKinesisException { - Instant threeMin = now.minus(Duration.standardMinutes(3)); - Instant twoMin = now.minus(Duration.standardMinutes(2)); - - when(firstIterator.getLatestRecordTimestamp()).thenReturn(threeMin).thenReturn(now); - when(secondIterator.getLatestRecordTimestamp()).thenReturn(twoMin); - - shardReadersPool.start(); - - assertThat(shardReadersPool.getLatestRecordTimestamp()).isEqualTo(threeMin); - assertThat(shardReadersPool.getLatestRecordTimestamp()).isEqualTo(twoMin); - - verify(firstIterator, times(2)).getLatestRecordTimestamp(); - verify(secondIterator, times(2)).getLatestRecordTimestamp(); - } - - @Test - public void shouldCallRateLimitPolicy() - throws TransientKinesisException, KinesisShardClosedException, InterruptedException { - KinesisClientThrottledException e = new KinesisClientThrottledException("", null); - when(firstIterator.readNextBatch()) - .thenThrow(e) - .thenReturn(ImmutableList.of(a, b)) - .thenReturn(Collections.emptyList()); - when(secondIterator.readNextBatch()) - .thenReturn(singletonList(c)) - .thenReturn(singletonList(d)) - .thenReturn(Collections.emptyList()); - when(rateLimitPolicyFactory.getRateLimitPolicy()).thenReturn(customRateLimitPolicy); - - shardReadersPool.start(); - List fetchedRecords = new ArrayList<>(); - while (fetchedRecords.size() < 4) { - CustomOptional nextRecord = shardReadersPool.nextRecord(); - if (nextRecord.isPresent()) { - fetchedRecords.add(nextRecord.get()); - } - } - - verify(customRateLimitPolicy, timeout(TIMEOUT_IN_MILLIS)).onThrottle(same(e)); - verify(customRateLimitPolicy, timeout(TIMEOUT_IN_MILLIS)).onSuccess(eq(ImmutableList.of(a, b))); - verify(customRateLimitPolicy, timeout(TIMEOUT_IN_MILLIS)).onSuccess(eq(singletonList(c))); - verify(customRateLimitPolicy, timeout(TIMEOUT_IN_MILLIS)).onSuccess(eq(singletonList(d))); - verify(customRateLimitPolicy, 
timeout(TIMEOUT_IN_MILLIS).atLeastOnce()) - .onSuccess(eq(Collections.emptyList())); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java deleted file mode 100644 index 397dc9831a9a..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.kinesis; - -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.Mockito.when; - -import com.amazonaws.services.kinesis.model.ExpiredIteratorException; -import java.io.IOException; -import java.util.Collections; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.junit.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; - -/** Tests {@link ShardRecordsIterator}. */ -@RunWith(MockitoJUnitRunner.Silent.class) -public class ShardRecordsIteratorTest { - - private static final String INITIAL_ITERATOR = "INITIAL_ITERATOR"; - private static final String SECOND_ITERATOR = "SECOND_ITERATOR"; - private static final String SECOND_REFRESHED_ITERATOR = "SECOND_REFRESHED_ITERATOR"; - private static final String THIRD_ITERATOR = "THIRD_ITERATOR"; - private static final String STREAM_NAME = "STREAM_NAME"; - private static final String SHARD_ID = "SHARD_ID"; - private static final Instant NOW = Instant.now(); - - @Mock private SimplifiedKinesisClient kinesisClient; - @Mock private ShardCheckpoint firstCheckpoint, aCheckpoint, bCheckpoint, cCheckpoint, dCheckpoint; - @Mock private GetKinesisRecordsResult firstResult, secondResult, thirdResult; - @Mock private KinesisRecord a, b, c, d; - @Mock private RecordFilter recordFilter; - - private ShardRecordsIterator iterator; - - @Before - public void setUp() throws IOException, TransientKinesisException { - when(firstCheckpoint.getShardIterator(kinesisClient)).thenReturn(INITIAL_ITERATOR); - when(firstCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - 
when(firstCheckpoint.getShardId()).thenReturn(SHARD_ID); - - when(firstCheckpoint.moveAfter(a)).thenReturn(aCheckpoint); - when(aCheckpoint.moveAfter(b)).thenReturn(bCheckpoint); - when(aCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(aCheckpoint.getShardId()).thenReturn(SHARD_ID); - when(bCheckpoint.moveAfter(c)).thenReturn(cCheckpoint); - when(bCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(bCheckpoint.getShardId()).thenReturn(SHARD_ID); - when(cCheckpoint.moveAfter(d)).thenReturn(dCheckpoint); - when(cCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(cCheckpoint.getShardId()).thenReturn(SHARD_ID); - when(dCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(dCheckpoint.getShardId()).thenReturn(SHARD_ID); - - when(kinesisClient.getRecords(INITIAL_ITERATOR, STREAM_NAME, SHARD_ID)).thenReturn(firstResult); - when(kinesisClient.getRecords(SECOND_ITERATOR, STREAM_NAME, SHARD_ID)).thenReturn(secondResult); - when(kinesisClient.getRecords(THIRD_ITERATOR, STREAM_NAME, SHARD_ID)).thenReturn(thirdResult); - - when(firstResult.getNextShardIterator()).thenReturn(SECOND_ITERATOR); - when(secondResult.getNextShardIterator()).thenReturn(THIRD_ITERATOR); - when(thirdResult.getNextShardIterator()).thenReturn(THIRD_ITERATOR); - - when(firstResult.getRecords()).thenReturn(Collections.emptyList()); - when(secondResult.getRecords()).thenReturn(Collections.emptyList()); - when(thirdResult.getRecords()).thenReturn(Collections.emptyList()); - - when(recordFilter.apply(anyList(), any(ShardCheckpoint.class))) - .thenAnswer(new IdentityAnswer()); - - WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy(); - iterator = - new ShardRecordsIterator( - firstCheckpoint, kinesisClient, watermarkPolicyFactory, recordFilter); - } - - @Test - public void goesThroughAvailableRecords() - throws IOException, TransientKinesisException, KinesisShardClosedException { - when(firstResult.getRecords()).thenReturn(asList(a, b, 
c)); - when(secondResult.getRecords()).thenReturn(singletonList(d)); - when(thirdResult.getRecords()).thenReturn(Collections.emptyList()); - - assertThat(iterator.getCheckpoint()).isEqualTo(firstCheckpoint); - assertThat(iterator.readNextBatch()).isEqualTo(asList(a, b, c)); - assertThat(iterator.readNextBatch()).isEqualTo(singletonList(d)); - assertThat(iterator.readNextBatch()).isEqualTo(Collections.emptyList()); - } - - @Test - public void conformingRecordsMovesCheckpoint() throws IOException, TransientKinesisException { - when(firstResult.getRecords()).thenReturn(asList(a, b, c)); - when(secondResult.getRecords()).thenReturn(singletonList(d)); - when(thirdResult.getRecords()).thenReturn(Collections.emptyList()); - - when(a.getApproximateArrivalTimestamp()).thenReturn(NOW); - when(b.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(1))); - when(c.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(2))); - when(d.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(3))); - - iterator.ackRecord(a); - assertThat(iterator.getCheckpoint()).isEqualTo(aCheckpoint); - iterator.ackRecord(b); - assertThat(iterator.getCheckpoint()).isEqualTo(bCheckpoint); - iterator.ackRecord(c); - assertThat(iterator.getCheckpoint()).isEqualTo(cCheckpoint); - iterator.ackRecord(d); - assertThat(iterator.getCheckpoint()).isEqualTo(dCheckpoint); - } - - @Test - public void refreshesExpiredIterator() - throws IOException, TransientKinesisException, KinesisShardClosedException { - when(firstResult.getRecords()).thenReturn(singletonList(a)); - when(secondResult.getRecords()).thenReturn(singletonList(b)); - - when(a.getApproximateArrivalTimestamp()).thenReturn(NOW); - when(b.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(1))); - - when(kinesisClient.getRecords(SECOND_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenThrow(ExpiredIteratorException.class); - 
when(aCheckpoint.getShardIterator(kinesisClient)).thenReturn(SECOND_REFRESHED_ITERATOR); - when(kinesisClient.getRecords(SECOND_REFRESHED_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenReturn(secondResult); - - assertThat(iterator.readNextBatch()).isEqualTo(singletonList(a)); - iterator.ackRecord(a); - assertThat(iterator.readNextBatch()).isEqualTo(singletonList(b)); - assertThat(iterator.readNextBatch()).isEqualTo(Collections.emptyList()); - } - - @Test - public void tracksLatestRecordTimestamp() { - when(firstResult.getRecords()).thenReturn(singletonList(a)); - when(secondResult.getRecords()).thenReturn(asList(b, c)); - when(thirdResult.getRecords()).thenReturn(singletonList(c)); - - when(a.getApproximateArrivalTimestamp()).thenReturn(NOW); - when(b.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(4))); - when(c.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(2))); - when(d.getApproximateArrivalTimestamp()).thenReturn(NOW.plus(Duration.standardSeconds(6))); - - iterator.ackRecord(a); - assertThat(iterator.getLatestRecordTimestamp()).isEqualTo(NOW); - iterator.ackRecord(b); - assertThat(iterator.getLatestRecordTimestamp()) - .isEqualTo(NOW.plus(Duration.standardSeconds(4))); - iterator.ackRecord(c); - assertThat(iterator.getLatestRecordTimestamp()) - .isEqualTo(NOW.plus(Duration.standardSeconds(4))); - iterator.ackRecord(d); - assertThat(iterator.getLatestRecordTimestamp()) - .isEqualTo(NOW.plus(Duration.standardSeconds(6))); - } - - private static class IdentityAnswer implements Answer { - - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - return invocation.getArguments()[0]; - } - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java deleted file mode 100644 index 4a7fed20af98..000000000000 --- 
a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java +++ /dev/null @@ -1,614 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; - -import com.amazonaws.AmazonServiceException; -import com.amazonaws.AmazonServiceException.ErrorType; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.model.Datapoint; -import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsRequest; -import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsResult; -import com.amazonaws.services.kinesis.AmazonKinesis; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.amazonaws.services.kinesis.model.DescribeStreamSummaryRequest; -import 
com.amazonaws.services.kinesis.model.DescribeStreamSummaryResult; -import com.amazonaws.services.kinesis.model.ExpiredIteratorException; -import com.amazonaws.services.kinesis.model.GetRecordsRequest; -import com.amazonaws.services.kinesis.model.GetRecordsResult; -import com.amazonaws.services.kinesis.model.GetShardIteratorRequest; -import com.amazonaws.services.kinesis.model.GetShardIteratorResult; -import com.amazonaws.services.kinesis.model.LimitExceededException; -import com.amazonaws.services.kinesis.model.ListShardsRequest; -import com.amazonaws.services.kinesis.model.ListShardsResult; -import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException; -import com.amazonaws.services.kinesis.model.Record; -import com.amazonaws.services.kinesis.model.Shard; -import com.amazonaws.services.kinesis.model.ShardFilter; -import com.amazonaws.services.kinesis.model.ShardFilterType; -import com.amazonaws.services.kinesis.model.ShardIteratorType; -import com.amazonaws.services.kinesis.model.StreamDescriptionSummary; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.function.Supplier; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.joda.time.Minutes; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; - -/** * */ -@RunWith(MockitoJUnitRunner.class) -public class SimplifiedKinesisClientTest { - - private static final String STREAM = "stream"; - private static final String SHARD_1 = "shard-01"; - private static final String SHARD_2 = "shard-02"; - private static final String SHARD_3 = "shard-03"; - private static final String SHARD_ITERATOR = "iterator"; - private static final String SEQUENCE_NUMBER = "abc123"; - private static final Instant CURRENT_TIMESTAMP = Instant.parse("2000-01-01T15:00:00.000Z"); - - 
@Mock private AmazonKinesis kinesis; - @Mock private AmazonCloudWatch cloudWatch; - @Mock private Supplier currentInstantSupplier; - @InjectMocks private SimplifiedKinesisClient underTest; - - @Test - public void shouldReturnIteratorStartingWithSequenceNumber() throws Exception { - when(kinesis.getShardIterator( - new GetShardIteratorRequest() - .withStreamName(STREAM) - .withShardId(SHARD_1) - .withShardIteratorType(ShardIteratorType.AT_SEQUENCE_NUMBER) - .withStartingSequenceNumber(SEQUENCE_NUMBER))) - .thenReturn(new GetShardIteratorResult().withShardIterator(SHARD_ITERATOR)); - - String stream = - underTest.getShardIterator( - STREAM, SHARD_1, ShardIteratorType.AT_SEQUENCE_NUMBER, SEQUENCE_NUMBER, null); - - assertThat(stream).isEqualTo(SHARD_ITERATOR); - } - - @Test - public void shouldReturnIteratorStartingWithTimestamp() throws Exception { - Instant timestamp = Instant.now(); - when(kinesis.getShardIterator( - new GetShardIteratorRequest() - .withStreamName(STREAM) - .withShardId(SHARD_1) - .withShardIteratorType(ShardIteratorType.AT_SEQUENCE_NUMBER) - .withTimestamp(timestamp.toDate()))) - .thenReturn(new GetShardIteratorResult().withShardIterator(SHARD_ITERATOR)); - - String stream = - underTest.getShardIterator( - STREAM, SHARD_1, ShardIteratorType.AT_SEQUENCE_NUMBER, null, timestamp); - - assertThat(stream).isEqualTo(SHARD_ITERATOR); - } - - @Test - public void shouldHandleExpiredIterationExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError( - new ExpiredIteratorException(""), ExpiredIteratorException.class); - } - - @Test - public void shouldHandleLimitExceededExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError( - new LimitExceededException(""), KinesisClientThrottledException.class); - } - - @Test - public void shouldHandleProvisionedThroughputExceededExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError( - new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class); - } - - 
@Test - public void shouldHandleServiceErrorForGetShardIterator() { - shouldHandleGetShardIteratorError( - newAmazonServiceException(ErrorType.Service), TransientKinesisException.class); - } - - @Test - public void shouldHandleClientErrorForGetShardIterator() { - shouldHandleGetShardIteratorError( - newAmazonServiceException(ErrorType.Client), RuntimeException.class); - } - - @Test - public void shouldHandleUnexpectedExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError(new NullPointerException(), RuntimeException.class); - } - - private void shouldHandleGetShardIteratorError( - Exception thrownException, Class expectedExceptionClass) { - GetShardIteratorRequest request = - new GetShardIteratorRequest() - .withStreamName(STREAM) - .withShardId(SHARD_1) - .withShardIteratorType(ShardIteratorType.LATEST); - - when(kinesis.getShardIterator(request)).thenThrow(thrownException); - - try { - underTest.getShardIterator(STREAM, SHARD_1, ShardIteratorType.LATEST, null, null); - failBecauseExceptionWasNotThrown(expectedExceptionClass); - } catch (Exception e) { - assertThat(e).isExactlyInstanceOf(expectedExceptionClass); - } finally { - reset(kinesis); - } - } - - @Test - public void shouldListAllShardsForTrimHorizon() throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - ShardFilter shardFilter = new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(shardFilter) - .withMaxResults(1_000))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = - underTest.listShardsAtPoint( - STREAM, new StartingPoint(InitialPositionInStream.TRIM_HORIZON)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldListAllShardsForTrimHorizonWithPagedResults() 
throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - ShardFilter shardFilter = new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - - String nextListShardsToken = "testNextToken"; - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(shardFilter) - .withMaxResults(1_000))) - .thenReturn( - new ListShardsResult().withShards(shard1, shard2).withNextToken(nextListShardsToken)); - - when(kinesis.listShards( - new ListShardsRequest() - .withMaxResults(1_000) - .withShardFilter(shardFilter) - .withNextToken(nextListShardsToken))) - .thenReturn(new ListShardsResult().withShards(shard3).withNextToken(null)); - - List shards = - underTest.listShardsAtPoint( - STREAM, new StartingPoint(InitialPositionInStream.TRIM_HORIZON)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldListAllShardsForTimestampWithinStreamRetentionAfterStreamCreationTimestamp() - throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - int hoursDifference = 1; - int retentionPeriodHours = hoursDifference * 3; - Instant streamCreationTimestamp = - CURRENT_TIMESTAMP.minus(Duration.standardHours(retentionPeriodHours)); - Instant startingPointTimestamp = - streamCreationTimestamp.plus(Duration.standardHours(hoursDifference)); - - when(currentInstantSupplier.get()).thenReturn(CURRENT_TIMESTAMP); - - when(kinesis.describeStreamSummary(new DescribeStreamSummaryRequest().withStreamName(STREAM))) - .thenReturn( - new DescribeStreamSummaryResult() - .withStreamDescriptionSummary( - new StreamDescriptionSummary() - .withRetentionPeriodHours(retentionPeriodHours) - .withStreamCreationTimestamp(streamCreationTimestamp.toDate()))); - - ShardFilter shardFilter = - new ShardFilter() - 
.withType(ShardFilterType.AT_TIMESTAMP) - .withTimestamp(startingPointTimestamp.toDate()); - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(shardFilter) - .withMaxResults(1_000))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = - underTest.listShardsAtPoint(STREAM, new StartingPoint(startingPointTimestamp)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void - shouldListAllShardsForTimestampWithRetriedDescribeStreamSummaryCallAfterStreamCreationTimestamp() - throws TransientKinesisException { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - int hoursDifference = 1; - int retentionPeriodHours = hoursDifference * 3; - Instant streamCreationTimestamp = - CURRENT_TIMESTAMP.minus(Duration.standardHours(retentionPeriodHours)); - Instant startingPointTimestamp = - streamCreationTimestamp.plus(Duration.standardHours(hoursDifference)); - - when(currentInstantSupplier.get()).thenReturn(CURRENT_TIMESTAMP); - - when(kinesis.describeStreamSummary(new DescribeStreamSummaryRequest().withStreamName(STREAM))) - .thenThrow(new LimitExceededException("Fake Exception: Limit exceeded")) - .thenReturn( - new DescribeStreamSummaryResult() - .withStreamDescriptionSummary( - new StreamDescriptionSummary() - .withRetentionPeriodHours(retentionPeriodHours) - .withStreamCreationTimestamp(streamCreationTimestamp.toDate()))); - - ShardFilter shardFilter = - new ShardFilter() - .withType(ShardFilterType.AT_TIMESTAMP) - .withTimestamp(startingPointTimestamp.toDate()); - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(shardFilter) - .withMaxResults(1_000))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = - 
underTest.listShardsAtPoint(STREAM, new StartingPoint(startingPointTimestamp)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldListAllShardsForTimestampOutsideStreamRetentionAfterStreamCreationTimestamp() - throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - int retentionPeriodHours = 3; - int startingPointHours = 5; - int hoursSinceStreamCreation = 6; - - Instant streamCreationTimestamp = - CURRENT_TIMESTAMP.minus(Duration.standardHours(hoursSinceStreamCreation)); - Instant startingPointTimestampAfterStreamRetentionTimestamp = - CURRENT_TIMESTAMP.minus(Duration.standardHours(startingPointHours)); - - when(currentInstantSupplier.get()).thenReturn(CURRENT_TIMESTAMP); - - DescribeStreamSummaryRequest describeStreamRequest = - new DescribeStreamSummaryRequest().withStreamName(STREAM); - when(kinesis.describeStreamSummary(describeStreamRequest)) - .thenReturn( - new DescribeStreamSummaryResult() - .withStreamDescriptionSummary( - new StreamDescriptionSummary() - .withRetentionPeriodHours(retentionPeriodHours) - .withStreamCreationTimestamp(streamCreationTimestamp.toDate()))); - - ShardFilter shardFilter = new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(shardFilter) - .withMaxResults(1_000))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = - underTest.listShardsAtPoint( - STREAM, new StartingPoint(startingPointTimestampAfterStreamRetentionTimestamp)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldListAllShardsForTimestampBeforeStreamCreationTimestamp() throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard 
shard3 = new Shard().withShardId(SHARD_3); - - Instant startingPointTimestamp = Instant.parse("2000-01-01T15:00:00.000Z"); - Instant streamCreationTimestamp = startingPointTimestamp.plus(Duration.standardHours(1)); - - DescribeStreamSummaryRequest describeStreamRequest = - new DescribeStreamSummaryRequest().withStreamName(STREAM); - when(kinesis.describeStreamSummary(describeStreamRequest)) - .thenReturn( - new DescribeStreamSummaryResult() - .withStreamDescriptionSummary( - new StreamDescriptionSummary() - .withStreamCreationTimestamp(streamCreationTimestamp.toDate()))); - - ShardFilter shardFilter = new ShardFilter().withType(ShardFilterType.AT_TRIM_HORIZON); - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(shardFilter) - .withMaxResults(1_000))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = - underTest.listShardsAtPoint(STREAM, new StartingPoint(startingPointTimestamp)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldListAllShardsForLatest() throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withShardFilter(new ShardFilter().withType(ShardFilterType.AT_LATEST)) - .withMaxResults(1_000))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = - underTest.listShardsAtPoint(STREAM, new StartingPoint(InitialPositionInStream.LATEST)); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldListAllShardsForExclusiveStartShardId() throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - - String 
exclusiveStartShardId = "exclusiveStartShardId"; - - when(kinesis.listShards( - new ListShardsRequest() - .withStreamName(STREAM) - .withMaxResults(1_000) - .withShardFilter( - new ShardFilter() - .withType(ShardFilterType.AFTER_SHARD_ID) - .withShardId(exclusiveStartShardId)))) - .thenReturn(new ListShardsResult().withShards(shard1, shard2, shard3).withNextToken(null)); - - List shards = underTest.listShardsFollowingClosedShard(STREAM, exclusiveStartShardId); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldHandleExpiredIterationExceptionForShardListing() { - shouldHandleShardListingError(new ExpiredIteratorException(""), ExpiredIteratorException.class); - } - - @Test - public void shouldHandleLimitExceededExceptionForShardListing() { - shouldHandleShardListingError( - new LimitExceededException(""), KinesisClientThrottledException.class); - } - - @Test - public void shouldHandleProvisionedThroughputExceededExceptionForShardListing() { - shouldHandleShardListingError( - new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class); - } - - @Test - public void shouldHandleServiceErrorForShardListing() { - shouldHandleShardListingError( - newAmazonServiceException(ErrorType.Service), TransientKinesisException.class); - } - - @Test - public void shouldHandleClientErrorForShardListing() { - shouldHandleShardListingError( - newAmazonServiceException(ErrorType.Client), RuntimeException.class); - } - - @Test - public void shouldHandleUnexpectedExceptionForShardListing() { - shouldHandleShardListingError(new NullPointerException(), RuntimeException.class); - } - - private void shouldHandleShardListingError( - Exception thrownException, Class expectedExceptionClass) { - when(kinesis.listShards(any(ListShardsRequest.class))).thenThrow(thrownException); - try { - underTest.listShardsAtPoint(STREAM, new StartingPoint(InitialPositionInStream.TRIM_HORIZON)); - 
failBecauseExceptionWasNotThrown(expectedExceptionClass); - } catch (Exception e) { - assertThat(e).isExactlyInstanceOf(expectedExceptionClass); - } finally { - reset(kinesis); - } - } - - @Test - public void shouldCountBytesWhenSingleDataPointReturned() throws Exception { - Instant countSince = new Instant("2017-04-06T10:00:00.000Z"); - Instant countTo = new Instant("2017-04-06T11:00:00.000Z"); - Minutes periodTime = Minutes.minutesBetween(countSince, countTo); - GetMetricStatisticsRequest metricStatisticsRequest = - underTest.createMetricStatisticsRequest(STREAM, countSince, countTo, periodTime); - GetMetricStatisticsResult result = - new GetMetricStatisticsResult().withDatapoints(new Datapoint().withSum(1.0)); - - when(cloudWatch.getMetricStatistics(metricStatisticsRequest)).thenReturn(result); - - long backlogBytes = underTest.getBacklogBytes(STREAM, countSince, countTo); - - assertThat(backlogBytes).isEqualTo(1L); - } - - @Test - public void shouldCountBytesWhenMultipleDataPointsReturned() throws Exception { - Instant countSince = new Instant("2017-04-06T10:00:00.000Z"); - Instant countTo = new Instant("2017-04-06T11:00:00.000Z"); - Minutes periodTime = Minutes.minutesBetween(countSince, countTo); - GetMetricStatisticsRequest metricStatisticsRequest = - underTest.createMetricStatisticsRequest(STREAM, countSince, countTo, periodTime); - GetMetricStatisticsResult result = - new GetMetricStatisticsResult() - .withDatapoints( - new Datapoint().withSum(1.0), - new Datapoint().withSum(3.0), - new Datapoint().withSum(2.0)); - - when(cloudWatch.getMetricStatistics(metricStatisticsRequest)).thenReturn(result); - - long backlogBytes = underTest.getBacklogBytes(STREAM, countSince, countTo); - - assertThat(backlogBytes).isEqualTo(6L); - } - - @Test - public void shouldNotCallCloudWatchWhenSpecifiedPeriodTooShort() throws Exception { - Instant countSince = new Instant("2017-04-06T10:00:00.000Z"); - Instant countTo = new Instant("2017-04-06T10:00:02.000Z"); - - long 
backlogBytes = underTest.getBacklogBytes(STREAM, countSince, countTo); - - assertThat(backlogBytes).isEqualTo(0L); - verifyZeroInteractions(cloudWatch); - } - - @Test - public void shouldHandleLimitExceededExceptionForGetBacklogBytes() { - shouldHandleGetBacklogBytesError( - new LimitExceededException(""), KinesisClientThrottledException.class); - } - - @Test - public void shouldHandleProvisionedThroughputExceededExceptionForGetBacklogBytes() { - shouldHandleGetBacklogBytesError( - new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class); - } - - @Test - public void shouldHandleServiceErrorForGetBacklogBytes() { - shouldHandleGetBacklogBytesError( - newAmazonServiceException(ErrorType.Service), TransientKinesisException.class); - } - - @Test - public void shouldHandleClientErrorForGetBacklogBytes() { - shouldHandleGetBacklogBytesError( - newAmazonServiceException(ErrorType.Client), RuntimeException.class); - } - - @Test - public void shouldHandleUnexpectedExceptionForGetBacklogBytes() { - shouldHandleGetBacklogBytesError(new NullPointerException(), RuntimeException.class); - } - - private void shouldHandleGetBacklogBytesError( - Exception thrownException, Class expectedExceptionClass) { - Instant countSince = new Instant("2017-04-06T10:00:00.000Z"); - Instant countTo = new Instant("2017-04-06T11:00:00.000Z"); - Minutes periodTime = Minutes.minutesBetween(countSince, countTo); - GetMetricStatisticsRequest metricStatisticsRequest = - underTest.createMetricStatisticsRequest(STREAM, countSince, countTo, periodTime); - - when(cloudWatch.getMetricStatistics(metricStatisticsRequest)).thenThrow(thrownException); - try { - underTest.getBacklogBytes(STREAM, countSince, countTo); - failBecauseExceptionWasNotThrown(expectedExceptionClass); - } catch (Exception e) { - assertThat(e).isExactlyInstanceOf(expectedExceptionClass); - } finally { - reset(kinesis); - } - } - - private AmazonServiceException newAmazonServiceException(ErrorType errorType) { 
- AmazonServiceException exception = new AmazonServiceException(""); - exception.setErrorType(errorType); - return exception; - } - - @Test - public void shouldReturnLimitedNumberOfRecords() throws Exception { - final Integer limit = 100; - - doAnswer( - (Answer) - invocation -> { - GetRecordsRequest request = (GetRecordsRequest) invocation.getArguments()[0]; - List records = generateRecords(request.getLimit()); - return new GetRecordsResult().withRecords(records).withMillisBehindLatest(1000L); - }) - .when(kinesis) - .getRecords(any(GetRecordsRequest.class)); - - GetKinesisRecordsResult result = underTest.getRecords(SHARD_ITERATOR, STREAM, SHARD_1, limit); - assertThat(result.getRecords().size()).isEqualTo(limit); - } - - private List generateRecords(int num) { - List records = new ArrayList<>(); - for (int i = 0; i < num; i++) { - byte[] value = new byte[1024]; - Arrays.fill(value, (byte) i); - records.add( - new Record() - .withSequenceNumber(String.valueOf(i)) - .withPartitionKey("key") - .withData(ByteBuffer.wrap(value))); - } - return records; - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicyTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicyTest.java deleted file mode 100644 index ce5c555a4dfb..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/WatermarkPolicyTest.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.powermock.api.mockito.PowerMockito.mockStatic; - -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -/** Tests {@link WatermarkPolicy}. 
*/ -@RunWith(PowerMockRunner.class) -@PrepareForTest(Instant.class) -public class WatermarkPolicyTest { - private static final Instant NOW = Instant.now(); - - @Test - public void shouldAdvanceWatermarkWithTheArrivalTimeFromKinesisRecords() { - WatermarkPolicy policy = WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy(); - - KinesisRecord a = mock(KinesisRecord.class); - KinesisRecord b = mock(KinesisRecord.class); - - Instant time1 = NOW.minus(Duration.standardSeconds(30L)); - Instant time2 = NOW.minus(Duration.standardSeconds(20L)); - when(a.getApproximateArrivalTimestamp()).thenReturn(time1); - when(b.getApproximateArrivalTimestamp()).thenReturn(time2); - - policy.update(a); - assertThat(policy.getWatermark()).isEqualTo(time1); - policy.update(b); - assertThat(policy.getWatermark()).isEqualTo(time2); - } - - @Test - public void shouldOnlyAdvanceTheWatermark() { - WatermarkPolicy policy = WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy(); - - KinesisRecord a = mock(KinesisRecord.class); - KinesisRecord b = mock(KinesisRecord.class); - KinesisRecord c = mock(KinesisRecord.class); - - Instant time1 = NOW.minus(Duration.standardSeconds(30L)); - Instant time2 = NOW.minus(Duration.standardSeconds(20L)); - Instant time3 = NOW.minus(Duration.standardSeconds(40L)); - when(a.getApproximateArrivalTimestamp()).thenReturn(time1); - when(b.getApproximateArrivalTimestamp()).thenReturn(time2); - // time3 is before time2 - when(c.getApproximateArrivalTimestamp()).thenReturn(time3); - - policy.update(a); - assertThat(policy.getWatermark()).isEqualTo(time1); - policy.update(b); - assertThat(policy.getWatermark()).isEqualTo(time2); - policy.update(c); - // watermark doesn't go back in time - assertThat(policy.getWatermark()).isEqualTo(time2); - } - - @Test - public void shouldAdvanceWatermarkWhenThereAreNoIncomingRecords() { - WatermarkParameters standardWatermarkParams = WatermarkParameters.create(); - WatermarkPolicy policy = - 
WatermarkPolicyFactory.withCustomWatermarkPolicy(standardWatermarkParams) - .createWatermarkPolicy(); - - mockStatic(Instant.class); - - Instant time1 = NOW.minus(Duration.standardSeconds(500)); // returned when update is called - Instant time2 = - NOW.minus( - Duration.standardSeconds(498)); // returned when getWatermark is called the first time - Instant time3 = NOW; // returned when getWatermark is called the second time - Instant arrivalTime = NOW.minus(Duration.standardSeconds(510)); - Duration watermarkIdleTimeThreshold = - standardWatermarkParams.getWatermarkIdleDurationThreshold(); - - when(Instant.now()).thenReturn(time1).thenReturn(time2).thenReturn(time3); - - KinesisRecord a = mock(KinesisRecord.class); - when(a.getApproximateArrivalTimestamp()).thenReturn(arrivalTime); - - policy.update(a); - - // returns the latest event time when the watermark - assertThat(policy.getWatermark()).isEqualTo(arrivalTime); - // advance the watermark to [NOW - watermark idle time threshold] - assertThat(policy.getWatermark()).isEqualTo(time3.minus(watermarkIdleTimeThreshold)); - } - - @Test - public void shouldAdvanceWatermarkToNowWithProcessingTimePolicy() { - WatermarkPolicy policy = - WatermarkPolicyFactory.withProcessingTimePolicy().createWatermarkPolicy(); - - mockStatic(Instant.class); - - Instant time1 = NOW.minus(Duration.standardSeconds(5)); - Instant time2 = NOW.minus(Duration.standardSeconds(4)); - - when(Instant.now()).thenReturn(time1).thenReturn(time2); - - assertThat(policy.getWatermark()).isEqualTo(time1); - assertThat(policy.getWatermark()).isEqualTo(time2); - } - - @Test - public void shouldAdvanceWatermarkWithCustomTimePolicy() { - SerializableFunction timestampFn = - (record) -> record.getApproximateArrivalTimestamp().plus(Duration.standardMinutes(1)); - - WatermarkPolicy policy = - WatermarkPolicyFactory.withCustomWatermarkPolicy( - WatermarkParameters.create().withTimestampFn(timestampFn)) - .createWatermarkPolicy(); - - KinesisRecord a = 
mock(KinesisRecord.class); - KinesisRecord b = mock(KinesisRecord.class); - - Instant time1 = NOW.minus(Duration.standardSeconds(30L)); - Instant time2 = NOW.minus(Duration.standardSeconds(20L)); - when(a.getApproximateArrivalTimestamp()).thenReturn(time1); - when(b.getApproximateArrivalTimestamp()).thenReturn(time2); - - policy.update(a); - assertThat(policy.getWatermark()).isEqualTo(time1.plus(Duration.standardMinutes(1))); - policy.update(b); - assertThat(policy.getWatermark()).isEqualTo(time2.plus(Duration.standardMinutes(1))); - } - - @Test - public void shouldUpdateWatermarkParameters() { - SerializableFunction fn = input -> Instant.now(); - Duration idleDurationThreshold = Duration.standardSeconds(30); - - WatermarkParameters parameters = - WatermarkParameters.create() - .withTimestampFn(fn) - .withWatermarkIdleDurationThreshold(idleDurationThreshold); - - assertThat(parameters.getTimestampFn()).isEqualTo(fn); - assertThat(parameters.getWatermarkIdleDurationThreshold()).isEqualTo(idleDurationThreshold); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/serde/AwsModuleTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/serde/AwsModuleTest.java deleted file mode 100644 index e58825ec4b9d..000000000000 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/serde/AwsModuleTest.java +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis.serde; - -import static org.apache.commons.lang3.reflect.FieldUtils.readField; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.hasItem; -import static org.hamcrest.Matchers.not; -import static org.junit.Assert.assertEquals; - -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSSessionCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.BasicSessionCredentials; -import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper; -import com.amazonaws.auth.EnvironmentVariableCredentialsProvider; -import com.amazonaws.auth.PropertiesFileCredentialsProvider; -import com.amazonaws.auth.SystemPropertiesCredentialsProvider; -import com.amazonaws.auth.profile.ProfileCredentialsProvider; -import com.fasterxml.jackson.databind.Module; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.util.List; -import org.apache.beam.sdk.util.common.ReflectHelpers; -import org.hamcrest.Matchers; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests {@link AwsModule}. 
*/ -@RunWith(JUnit4.class) -public class AwsModuleTest { - private static final String ACCESS_KEY_ID = "ACCESS_KEY_ID"; - private static final String SECRET_ACCESS_KEY = "SECRET_ACCESS_KEY"; - private static final String SESSION_TOKEN = "SESSION_TOKEN"; - - private final ObjectMapper objectMapper = new ObjectMapper().registerModule(new AwsModule()); - - private String serialize(Object obj) throws Exception { - return objectMapper.writeValueAsString(obj); - } - - private T deserialize(String serializedObj, Class clazz) throws Exception { - return objectMapper.readValue(serializedObj, clazz); - } - - private AWSCredentialsProvider deserializeCredentialsProvider(String serializedProvider) - throws Exception { - return deserialize(serializedProvider, AWSCredentialsProvider.class); - } - - @Test - public void testObjectMapperCannotFindModule() { - // module shall not be discoverable to not conflict with the one in amazon-web-services - List modules = ObjectMapper.findModules(ReflectHelpers.findClassLoader()); - assertThat(modules, not(hasItem(Matchers.instanceOf(AwsModule.class)))); - } - - private void checkStaticBasicCredentials(AWSCredentialsProvider provider) { - assertEquals(AWSStaticCredentialsProvider.class, provider.getClass()); - assertEquals(ACCESS_KEY_ID, provider.getCredentials().getAWSAccessKeyId()); - assertEquals(SECRET_ACCESS_KEY, provider.getCredentials().getAWSSecretKey()); - } - - private void checkStaticSessionCredentials(AWSCredentialsProvider provider) { - checkStaticBasicCredentials(provider); - assertEquals( - SESSION_TOKEN, ((AWSSessionCredentials) provider.getCredentials()).getSessionToken()); - } - - @Test - public void testAWSStaticCredentialsProviderSerializationDeserialization() throws Exception { - AWSCredentialsProvider credentialsProvider = - new AWSStaticCredentialsProvider(new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)); - - String serializedCredentialsProvider = serialize(credentialsProvider); - AWSCredentialsProvider 
deserializedCredentialsProvider = - deserializeCredentialsProvider(serializedCredentialsProvider); - - checkStaticBasicCredentials(deserializedCredentialsProvider); - - credentialsProvider = - new AWSStaticCredentialsProvider( - new BasicSessionCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, SESSION_TOKEN)); - - checkStaticSessionCredentials(credentialsProvider); - } - - @Test - public void testPropertiesFileCredentialsProviderSerializationDeserialization() throws Exception { - String credentialsFilePath = "/path/to/file"; - - AWSCredentialsProvider credentialsProvider = - new PropertiesFileCredentialsProvider(credentialsFilePath); - - String serializedCredentialsProvider = serialize(credentialsProvider); - AWSCredentialsProvider deserializedCredentialsProvider = - deserializeCredentialsProvider(serializedCredentialsProvider); - - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals( - credentialsFilePath, - readField(deserializedCredentialsProvider, "credentialsFilePath", true)); - } - - @Test - public void testClasspathPropertiesFileCredentialsProviderSerializationDeserialization() - throws Exception { - String credentialsFilePath = "/path/to/file"; - - AWSCredentialsProvider credentialsProvider = - new ClasspathPropertiesFileCredentialsProvider(credentialsFilePath); - - String serializedCredentialsProvider = serialize(credentialsProvider); - AWSCredentialsProvider deserializedCredentialsProvider = - deserializeCredentialsProvider(serializedCredentialsProvider); - - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - assertEquals( - credentialsFilePath, - readField(deserializedCredentialsProvider, "credentialsFilePath", true)); - } - - @Test - public void testSingletonAWSCredentialsProviderSerializationDeserialization() throws Exception { - AWSCredentialsProvider credentialsProvider; - String serializedCredentialsProvider; - AWSCredentialsProvider 
deserializedCredentialsProvider; - - credentialsProvider = new DefaultAWSCredentialsProviderChain(); - serializedCredentialsProvider = serialize(credentialsProvider); - deserializedCredentialsProvider = deserializeCredentialsProvider(serializedCredentialsProvider); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new EnvironmentVariableCredentialsProvider(); - serializedCredentialsProvider = serialize(credentialsProvider); - deserializedCredentialsProvider = deserializeCredentialsProvider(serializedCredentialsProvider); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new SystemPropertiesCredentialsProvider(); - serializedCredentialsProvider = serialize(credentialsProvider); - deserializedCredentialsProvider = deserializeCredentialsProvider(serializedCredentialsProvider); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new ProfileCredentialsProvider(); - serializedCredentialsProvider = serialize(credentialsProvider); - deserializedCredentialsProvider = deserializeCredentialsProvider(serializedCredentialsProvider); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - - credentialsProvider = new EC2ContainerCredentialsProviderWrapper(); - serializedCredentialsProvider = serialize(credentialsProvider); - deserializedCredentialsProvider = deserializeCredentialsProvider(serializedCredentialsProvider); - assertEquals(credentialsProvider.getClass(), deserializedCredentialsProvider.getClass()); - } -} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/serde/AwsSerializableUtilsTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/serde/AwsSerializableUtilsTest.java deleted file mode 100644 index 972912be2a94..000000000000 --- 
a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/serde/AwsSerializableUtilsTest.java +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.kinesis.serde; - -import static org.apache.beam.sdk.io.kinesis.serde.AwsSerializableUtils.deserialize; -import static org.apache.beam.sdk.io.kinesis.serde.AwsSerializableUtils.serialize; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSSessionCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.BasicSessionCredentials; -import com.amazonaws.auth.ClasspathPropertiesFileCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper; -import com.amazonaws.auth.EnvironmentVariableCredentialsProvider; -import com.amazonaws.auth.PropertiesFileCredentialsProvider; -import com.amazonaws.auth.SystemPropertiesCredentialsProvider; -import 
com.amazonaws.auth.profile.ProfileCredentialsProvider; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class AwsSerializableUtilsTest { - - private static final String ACCESS_KEY_ID = "ACCESS_KEY_ID"; - private static final String SECRET_ACCESS_KEY = "SECRET_ACCESS_KEY"; - private static final String SESSION_TOKEN = "SESSION_TOKEN"; - - private void checkStaticBasicCredentials(AWSCredentialsProvider provider) { - assertTrue(provider instanceof AWSStaticCredentialsProvider); - assertEquals(ACCESS_KEY_ID, provider.getCredentials().getAWSAccessKeyId()); - assertEquals(SECRET_ACCESS_KEY, provider.getCredentials().getAWSSecretKey()); - } - - private void checkStaticSessionCredentials(AWSCredentialsProvider provider) { - checkStaticBasicCredentials(provider); - assertEquals( - SESSION_TOKEN, ((AWSSessionCredentials) provider.getCredentials()).getSessionToken()); - } - - @Test - public void testBasicCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = - new AWSStaticCredentialsProvider(new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)); - String serializedProvider = serialize(credentialsProvider); - - checkStaticBasicCredentials(deserialize(serializedProvider)); - } - - @Test - public void testStaticSessionCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = - new AWSStaticCredentialsProvider( - new BasicSessionCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, SESSION_TOKEN)); - String serializedCredentials = serialize(credentialsProvider); - - checkStaticSessionCredentials(deserialize(serializedCredentials)); - } - - @Test - public void testDefaultAWSCredentialsProviderChainSerialization() { - AWSCredentialsProvider credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - String expectedSerializedProvider = "{\"@type\":\"DefaultAWSCredentialsProviderChain\"}"; - String serializedProvider = 
serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - @Test - public void testPropertiesFileCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = - new PropertiesFileCredentialsProvider("AwsCredentials.properties"); - String expectedSerializedProvider = - "{\"@type\":\"PropertiesFileCredentialsProvider\",\"credentialsFilePath\":\"AwsCredentials.properties\"}"; - String serializedProvider = serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - @Test - public void testClasspathPropertiesFileCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = - new ClasspathPropertiesFileCredentialsProvider("AwsCredentials.properties"); - String expectedSerializedProvider = - "{\"@type\":\"ClasspathPropertiesFileCredentialsProvider\",\"credentialsFilePath\":\"/AwsCredentials.properties\"}"; - String serializedProvider = serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - @Test - public void testEnvironmentVariableCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = new EnvironmentVariableCredentialsProvider(); - String expectedSerializedProvider = "{\"@type\":\"EnvironmentVariableCredentialsProvider\"}"; - String serializedProvider = serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - @Test - public void testSystemPropertiesCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = new 
SystemPropertiesCredentialsProvider(); - String expectedSerializedProvider = "{\"@type\":\"SystemPropertiesCredentialsProvider\"}"; - String serializedProvider = serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - @Test - public void testProfileCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = new ProfileCredentialsProvider(); - String expectedSerializedProvider = "{\"@type\":\"ProfileCredentialsProvider\"}"; - String serializedProvider = serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - @Test - public void testEC2ContainerCredentialsProviderWrapperSerialization() { - AWSCredentialsProvider credentialsProvider = new EC2ContainerCredentialsProviderWrapper(); - String expectedSerializedProvider = "{\"@type\":\"EC2ContainerCredentialsProviderWrapper\"}"; - String serializedProvider = serialize(credentialsProvider); - - assertEquals(expectedSerializedProvider, serializedProvider); - assertEquals(expectedSerializedProvider, serialize(deserialize(serializedProvider))); - } - - static class UnknownAwsCredentialsProvider implements AWSCredentialsProvider { - @Override - public AWSCredentials getCredentials() { - return new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY); - } - - @Override - public void refresh() {} - } - - @Test(expected = IllegalArgumentException.class) - public void testFailOnAWSCredentialsProviderSerialization() { - AWSCredentialsProvider credentialsProvider = new UnknownAwsCredentialsProvider(); - serialize(credentialsProvider); - } - - @Test(expected = IllegalArgumentException.class) - public void testFailOnAWSCredentialsProviderDeserialization() { - deserialize("invalid string"); - } -} diff --git 
a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java index b2196dbf1067..d4c9a3ec6210 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java @@ -102,10 +102,7 @@ public void close() { if (messageReceiver != null) { messageReceiver.close(); } - if (messageProducer != null) { - messageProducer.close(); - } - if (!isClosed()) { + if (jcsmpSession != null) { checkStateNotNull(jcsmpSession).closeSession(); } return 0; @@ -119,8 +116,9 @@ public MessageReceiver getReceiver() { this.messageReceiver = retryCallableManager.retryCallable( this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + this.messageReceiver.start(); } - return this.messageReceiver; + return checkStateNotNull(this.messageReceiver); } @Override @@ -138,15 +136,10 @@ public java.util.Queue getPublishedResultsQueue() { return publishedResultsQueue; } - @Override - public boolean isClosed() { - return jcsmpSession == null || jcsmpSession.isClosed(); - } - private MessageProducer createXMLMessageProducer(SubmissionMode submissionMode) throws JCSMPException, IOException { - if (isClosed()) { + if (jcsmpSession == null) { connectWriteSession(submissionMode); } @@ -165,9 +158,6 @@ private MessageProducer createXMLMessageProducer(SubmissionMode submissionMode) } private MessageReceiver createFlowReceiver() throws JCSMPException, IOException { - if (isClosed()) { - connectSession(); - } Queue queue = JCSMPFactory.onlyInstance() diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java index 95f989bd1be9..017a63260678 100644 --- 
a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java @@ -35,13 +35,6 @@ public interface MessageReceiver { */ void start(); - /** - * Returns {@literal true} if the message receiver is closed, {@literal false} otherwise. - * - *

A message receiver is closed when it is no longer able to receive messages. - */ - boolean isClosed(); - /** * Receives a message from the broker. * diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java index 84a876a9d0bc..6dcd0b652616 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java @@ -120,13 +120,6 @@ public abstract class SessionService implements Serializable { /** Gracefully closes the connection to the service. */ public abstract void close(); - /** - * Checks whether the connection to the service is currently closed. This method is called when an - * `UnboundedSolaceReader` is starting to read messages - a session will be created if this - * returns true. - */ - public abstract boolean isClosed(); - /** * Returns a MessageReceiver object for receiving messages from Solace. 
If it is the first time * this method is used, the receiver is created from the session instance, otherwise it returns diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java index d548d2049a5b..d74f3cae89fe 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java @@ -24,12 +24,8 @@ import java.io.IOException; import org.apache.beam.sdk.io.solace.RetryCallableManager; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class SolaceMessageReceiver implements MessageReceiver { - private static final Logger LOG = LoggerFactory.getLogger(SolaceMessageReceiver.class); - public static final int DEFAULT_ADVANCE_TIMEOUT_IN_MILLIS = 100; private final FlowReceiver flowReceiver; private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); @@ -52,19 +48,14 @@ private void startFlowReceiver() { ImmutableSet.of(JCSMPException.class)); } - @Override - public boolean isClosed() { - return flowReceiver == null || flowReceiver.isClosed(); - } - @Override public BytesXMLMessage receive() throws IOException { try { return flowReceiver.receive(DEFAULT_ADVANCE_TIMEOUT_IN_MILLIS); } catch (StaleSessionException e) { - LOG.warn("SolaceIO: Caught StaleSessionException, restarting the FlowReceiver."); startFlowReceiver(); - throw new IOException(e); + throw new IOException( + "SolaceIO: Caught StaleSessionException, restarting the FlowReceiver.", e); } catch (JCSMPException e) { throw new IOException(e); } @@ -72,8 +63,6 @@ public BytesXMLMessage receive() throws IOException { @Override public void close() { - if (!isClosed()) { - 
this.flowReceiver.close(); - } + flowReceiver.close(); } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/SolaceCheckpointMark.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/SolaceCheckpointMark.java index 77f6eed8f62c..a913fd6133ea 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/SolaceCheckpointMark.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/SolaceCheckpointMark.java @@ -18,17 +18,16 @@ package org.apache.beam.sdk.io.solace.read; import com.solacesystems.jcsmp.BytesXMLMessage; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Queue; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.DefaultCoder; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Checkpoint for an unbounded Solace source. Consists of the Solace messages waiting to be @@ -38,10 +37,8 @@ @Internal @VisibleForTesting public class SolaceCheckpointMark implements UnboundedSource.CheckpointMark { - private transient AtomicBoolean activeReader; - // BytesXMLMessage is not serializable so if a job restarts from the checkpoint, we cannot retry - // these messages here. We relay on Solace's retry mechanism. 
- private transient ArrayDeque ackQueue; + private static final Logger LOG = LoggerFactory.getLogger(SolaceCheckpointMark.class); + private transient Queue safeToAck; @SuppressWarnings("initialization") // Avro will set the fields by breaking abstraction private SolaceCheckpointMark() {} @@ -49,25 +46,24 @@ private SolaceCheckpointMark() {} /** * Creates a new {@link SolaceCheckpointMark}. * - * @param activeReader {@link AtomicBoolean} indicating if the related reader is active. The - * reader creating the messages has to be active to acknowledge the messages. - * @param ackQueue {@link List} of {@link BytesXMLMessage} to be acknowledged. + * @param safeToAck - a queue of {@link BytesXMLMessage} to be acknowledged. */ - SolaceCheckpointMark(AtomicBoolean activeReader, List ackQueue) { - this.activeReader = activeReader; - this.ackQueue = new ArrayDeque<>(ackQueue); + SolaceCheckpointMark(Queue safeToAck) { + this.safeToAck = safeToAck; } @Override public void finalizeCheckpoint() { - if (activeReader == null || !activeReader.get() || ackQueue == null) { - return; - } - - while (!ackQueue.isEmpty()) { - BytesXMLMessage msg = ackQueue.poll(); - if (msg != null) { + BytesXMLMessage msg; + while ((msg = safeToAck.poll()) != null) { + try { msg.ackMessage(); + } catch (IllegalStateException e) { + LOG.error( + "SolaceIO.Read: cannot acknowledge the message with applicationMessageId={}, ackMessageId={}. It will not be retried.", + msg.getApplicationMessageId(), + msg.getAckMessageId(), + e); } } } @@ -84,15 +80,11 @@ public boolean equals(@Nullable Object o) { return false; } SolaceCheckpointMark that = (SolaceCheckpointMark) o; - // Needed to convert to ArrayList because ArrayDeque.equals checks only for reference, not - // content. 
- ArrayList ackList = new ArrayList<>(ackQueue); - ArrayList thatAckList = new ArrayList<>(that.ackQueue); - return Objects.equals(activeReader, that.activeReader) && Objects.equals(ackList, thatAckList); + return Objects.equals(safeToAck, that.safeToAck); } @Override public int hashCode() { - return Objects.hash(activeReader, ackQueue); + return Objects.hash(safeToAck); } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java index a421970370da..dc84e0a07017 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java @@ -22,17 +22,26 @@ import com.solacesystems.jcsmp.BytesXMLMessage; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.List; import java.util.NoSuchElementException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Queue; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.io.solace.broker.SempClient; import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; import org.slf4j.Logger; @@ -46,48 +55,92 @@ class UnboundedSolaceReader extends UnboundedReader { private final UnboundedSolaceSource currentSource; private final WatermarkPolicy watermarkPolicy; private final SempClient sempClient; + private final UUID readerUuid; + private final SessionServiceFactory sessionServiceFactory; private @Nullable BytesXMLMessage solaceOriginalRecord; private @Nullable T solaceMappedRecord; - private @Nullable SessionService sessionService; - AtomicBoolean active = new AtomicBoolean(true); /** - * Queue to place advanced messages before {@link #getCheckpointMark()} be called non-concurrent - * queue, should only be accessed by the reader thread A given {@link UnboundedReader} object will - * only be accessed by a single thread at once. + * Queue to place advanced messages before {@link #getCheckpointMark()} is called. CAUTION: + * Accessed by both reader and checkpointing threads. */ - private final java.util.Queue elementsToCheckpoint = new ArrayDeque<>(); + private final Queue safeToAckMessages = new ConcurrentLinkedQueue<>(); + + /** + * Queue for messages that were ingested in the {@link #advance()} method, but not sent yet to a + * {@link SolaceCheckpointMark}. 
+ */ + private final Queue receivedMessages = new ArrayDeque<>(); + + private static final Cache sessionServiceCache; + private static final ScheduledExecutorService cleanUpThread = Executors.newScheduledThreadPool(1); + + static { + Duration cacheExpirationTimeout = Duration.ofMinutes(1); + sessionServiceCache = + CacheBuilder.newBuilder() + .expireAfterAccess(cacheExpirationTimeout) + .removalListener( + (RemovalNotification notification) -> { + LOG.info( + "SolaceIO.Read: Closing session for the reader with uuid {} as it has been idle for over {}.", + notification.getKey(), + cacheExpirationTimeout); + SessionService sessionService = notification.getValue(); + if (sessionService != null) { + sessionService.close(); + } + }) + .build(); + + startCleanUpThread(); + } + + @SuppressWarnings("FutureReturnValueIgnored") + private static void startCleanUpThread() { + cleanUpThread.scheduleAtFixedRate(sessionServiceCache::cleanUp, 1, 1, TimeUnit.MINUTES); + } public UnboundedSolaceReader(UnboundedSolaceSource currentSource) { this.currentSource = currentSource; this.watermarkPolicy = WatermarkPolicy.create( currentSource.getTimestampFn(), currentSource.getWatermarkIdleDurationThreshold()); - this.sessionService = currentSource.getSessionServiceFactory().create(); + this.sessionServiceFactory = currentSource.getSessionServiceFactory(); this.sempClient = currentSource.getSempClientFactory().create(); + this.readerUuid = UUID.randomUUID(); + } + + private SessionService getSessionService() { + try { + return sessionServiceCache.get( + readerUuid, + () -> { + LOG.info("SolaceIO.Read: creating a new session for reader with uuid {}.", readerUuid); + SessionService sessionService = sessionServiceFactory.create(); + sessionService.connect(); + sessionService.getReceiver().start(); + return sessionService; + }); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } } @Override public boolean start() { - populateSession(); - 
checkNotNull(sessionService).getReceiver().start(); + // Create and initialize SessionService with Receiver + getSessionService(); return advance(); } - public void populateSession() { - if (sessionService == null) { - sessionService = getCurrentSource().getSessionServiceFactory().create(); - } - if (sessionService.isClosed()) { - checkNotNull(sessionService).connect(); - } - } - @Override public boolean advance() { + finalizeReadyMessages(); + BytesXMLMessage receivedXmlMessage; try { - receivedXmlMessage = checkNotNull(sessionService).getReceiver().receive(); + receivedXmlMessage = getSessionService().getReceiver().receive(); } catch (IOException e) { LOG.warn("SolaceIO.Read: Exception when pulling messages from the broker.", e); return false; @@ -96,23 +149,40 @@ public boolean advance() { if (receivedXmlMessage == null) { return false; } - elementsToCheckpoint.add(receivedXmlMessage); solaceOriginalRecord = receivedXmlMessage; solaceMappedRecord = getCurrentSource().getParseFn().apply(receivedXmlMessage); - watermarkPolicy.update(solaceMappedRecord); + receivedMessages.add(receivedXmlMessage); + return true; } @Override public void close() { - active.set(false); - checkNotNull(sessionService).close(); + finalizeReadyMessages(); + sessionServiceCache.invalidate(readerUuid); + } + + public void finalizeReadyMessages() { + BytesXMLMessage msg; + while ((msg = safeToAckMessages.poll()) != null) { + try { + msg.ackMessage(); + } catch (IllegalStateException e) { + LOG.error( + "SolaceIO.Read: failed to acknowledge the message with applicationMessageId={}, ackMessageId={}. 
Returning the message to queue to retry.", + msg.getApplicationMessageId(), + msg.getAckMessageId(), + e); + safeToAckMessages.add(msg); // In case the error was transient, might succeed later + break; // Commit is only best effort + } + } } @Override public Instant getWatermark() { // should be only used by a test receiver - if (checkNotNull(sessionService).getReceiver().isEOF()) { + if (getSessionService().getReceiver().isEOF()) { return BoundedWindow.TIMESTAMP_MAX_VALUE; } return watermarkPolicy.getWatermark(); @@ -120,14 +190,9 @@ public Instant getWatermark() { @Override public UnboundedSource.CheckpointMark getCheckpointMark() { - List ackQueue = new ArrayList<>(); - while (!elementsToCheckpoint.isEmpty()) { - BytesXMLMessage msg = elementsToCheckpoint.poll(); - if (msg != null) { - ackQueue.add(msg); - } - } - return new SolaceCheckpointMark(active, ackQueue); + safeToAckMessages.addAll(receivedMessages); + receivedMessages.clear(); + return new SolaceCheckpointMark(safeToAckMessages); } @Override diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java index 38b4953a5984..7631d32f63cc 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java @@ -40,11 +40,6 @@ public void close() { throw new UnsupportedOperationException(exceptionMessage); } - @Override - public boolean isClosed() { - throw new UnsupportedOperationException(exceptionMessage); - } - @Override public MessageReceiver getReceiver() { throw new UnsupportedOperationException(exceptionMessage); diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java index 
bd52dee7ea86..6d28bcefc84c 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java @@ -77,11 +77,6 @@ public abstract Builder mockProducerFn( @Override public void close() {} - @Override - public boolean isClosed() { - return false; - } - @Override public MessageReceiver getReceiver() { if (messageReceiver == null) { @@ -131,11 +126,6 @@ public MockReceiver( @Override public void start() {} - @Override - public boolean isClosed() { - return false; - } - @Override public BytesXMLMessage receive() throws IOException { return getRecordFn.apply(counter.getAndIncrement()); diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java index c718c55e1b48..a1f80932eddf 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java @@ -447,25 +447,29 @@ public void testCheckpointMarkAndFinalizeSeparately() throws Exception { // start the reader and move to the first record assertTrue(reader.start()); - // consume 3 messages (NB: start already consumed the first message) + // consume 3 messages (NB: #start() already consumed the first message) for (int i = 0; i < 3; i++) { assertTrue(String.format("Failed at %d-th message", i), reader.advance()); } - // create checkpoint but don't finalize yet + // #advance() was called, but the messages were not ready to be acknowledged. + assertEquals(0, countAckMessages.get()); + + // mark all consumed messages as ready to be acknowledged CheckpointMark checkpointMark = reader.getCheckpointMark(); - // consume 2 more messages - reader.advance(); + // consume 1 more message. This will call #ackMsg() on messages that were ready to be acked. 
reader.advance(); + assertEquals(4, countAckMessages.get()); - // check if messages are still not acknowledged - assertEquals(0, countAckMessages.get()); + // consume 1 more message. No change in the acknowledged messages. + reader.advance(); + assertEquals(4, countAckMessages.get()); // acknowledge from the first checkpoint checkpointMark.finalizeCheckpoint(); - - // only messages from the first checkpoint are acknowledged + // No change in the acknowledged messages, because they were acknowledged in the #advance() + // method. assertEquals(4, countAckMessages.get()); } diff --git a/sdks/java/testing/load-tests/build.gradle b/sdks/java/testing/load-tests/build.gradle index d1439bafb748..c74c7301db74 100644 --- a/sdks/java/testing/load-tests/build.gradle +++ b/sdks/java/testing/load-tests/build.gradle @@ -64,6 +64,10 @@ configurations { gradleRun } +def excludeNetty = { + exclude group: "io.netty", module: "*" // exclude more recent Netty version +} + dependencies { implementation enforcedPlatform(library.java.google_cloud_platform_libraries_bom) @@ -73,8 +77,9 @@ dependencies { implementation project(":sdks:java:testing:test-utils") implementation project(":sdks:java:io:google-cloud-platform") implementation project(":sdks:java:io:kafka") - implementation project(":sdks:java:io:kinesis") - implementation library.java.aws_java_sdk_core + implementation project(":sdks:java:io:amazon-web-services2") + implementation library.java.aws_java_sdk2_auth, excludeNetty + implementation library.java.aws_java_sdk2_regions, excludeNetty implementation library.java.google_cloud_core implementation library.java.joda_time implementation library.java.vendored_guava_32_1_2_jre diff --git a/sdks/java/testing/load-tests/src/main/java/org/apache/beam/sdk/loadtests/SyntheticDataPublisher.java b/sdks/java/testing/load-tests/src/main/java/org/apache/beam/sdk/loadtests/SyntheticDataPublisher.java index 525582451bd9..3bd87480b8bc 100644 --- 
a/sdks/java/testing/load-tests/src/main/java/org/apache/beam/sdk/loadtests/SyntheticDataPublisher.java +++ b/sdks/java/testing/load-tests/src/main/java/org/apache/beam/sdk/loadtests/SyntheticDataPublisher.java @@ -20,7 +20,6 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; -import com.amazonaws.regions.Regions; import java.io.IOException; import java.util.Arrays; import java.util.HashMap; @@ -30,10 +29,11 @@ import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; +import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO; import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; import org.apache.beam.sdk.io.kafka.KafkaIO; -import org.apache.beam.sdk.io.kinesis.KinesisIO; import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource; import org.apache.beam.sdk.io.synthetic.SyntheticOptions; import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions; @@ -47,6 +47,9 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.kafka.common.serialization.StringSerializer; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; /** * Pipeline that generates synthetic data and publishes it in a PubSub or Kafka topic or in a @@ -180,17 +183,21 @@ private static void writeToKafka(PCollection> collection) { } private static void writeToKinesis(PCollection> collection) { + AwsBasicCredentials creds = + AwsBasicCredentials.create(options.getKinesisAwsKey(), options.getKinesisAwsSecret()); + StaticCredentialsProvider provider = StaticCredentialsProvider.create(creds); collection .apply("Map to byte array for Kinesis", 
MapElements.via(new MapKVToByteArray())) .apply( "Write to Kinesis", - KinesisIO.write() + KinesisIO.write() .withStreamName(options.getKinesisStreamName()) - .withPartitionKey(options.getKinesisPartitionKey()) - .withAWSClientsProvider( - options.getKinesisAwsKey(), - options.getKinesisAwsSecret(), - Regions.fromName(options.getKinesisAwsRegion()))); + .withPartitioner(p -> options.getKinesisPartitionKey()) + .withClientConfiguration( + ClientConfiguration.builder() + .credentialsProvider(provider) + .region(Region.of(options.getKinesisAwsRegion())) + .build())); } private static class MapKVToString extends SimpleFunction, String> { diff --git a/sdks/java/testing/watermarks/build.gradle b/sdks/java/testing/watermarks/build.gradle index c6c2a50279cc..ca774815467a 100644 --- a/sdks/java/testing/watermarks/build.gradle +++ b/sdks/java/testing/watermarks/build.gradle @@ -69,7 +69,6 @@ dependencies { runtimeOnly project(":sdks:java:testing:test-utils") runtimeOnly project(":sdks:java:io:google-cloud-platform") runtimeOnly project(":sdks:java:io:kafka") - runtimeOnly project(":sdks:java:io:kinesis") gradleRun project(project.path) gradleRun project(path: runnerDependency, configuration: runnerConfiguration) diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd b/sdks/python/apache_beam/coders/coder_impl.pxd index 52889fa2fd92..8a28499555c1 100644 --- a/sdks/python/apache_beam/coders/coder_impl.pxd +++ b/sdks/python/apache_beam/coders/coder_impl.pxd @@ -219,6 +219,18 @@ cdef libc.stdint.int64_t MIN_TIMESTAMP_micros cdef libc.stdint.int64_t MAX_TIMESTAMP_micros +cdef class _OrderedUnionCoderImpl(StreamCoderImpl): + cdef tuple _types + cdef tuple _coder_impls + cdef CoderImpl _fallback_coder_impl + + @cython.locals(ix=int, c=CoderImpl) + cpdef encode_to_stream(self, value, OutputStream stream, bint nested) + + @cython.locals(ix=int, c=CoderImpl) + cpdef decode_from_stream(self, InputStream stream, bint nested) + + cdef class WindowedValueCoderImpl(StreamCoderImpl): 
"""A coder for windowed values.""" cdef CoderImpl _value_coder diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 5262e6adf8a6..5dff35052901 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -1421,6 +1421,37 @@ def estimate_size(self, value, nested=False): return size +class _OrderedUnionCoderImpl(StreamCoderImpl): + def __init__(self, coder_impl_types, fallback_coder_impl): + assert len(coder_impl_types) < 128 + self._types, self._coder_impls = zip(*coder_impl_types) + self._fallback_coder_impl = fallback_coder_impl + + def encode_to_stream(self, value, out, nested): + value_t = type(value) + for (ix, t) in enumerate(self._types): + if value_t is t: + out.write_byte(ix) + c = self._coder_impls[ix] # for typing + c.encode_to_stream(value, out, nested) + break + else: + if self._fallback_coder_impl is None: + raise ValueError("No fallback.") + out.write_byte(0xFF) + self._fallback_coder_impl.encode_to_stream(value, out, nested) + + def decode_from_stream(self, in_stream, nested): + ix = in_stream.read_byte() + if ix == 0xFF: + if self._fallback_coder_impl is None: + raise ValueError("No fallback.") + return self._fallback_coder_impl.decode_from_stream(in_stream, nested) + else: + c = self._coder_impls[ix] # for typing + return c.decode_from_stream(in_stream, nested) + + class WindowedValueCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. 
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index a0c55da81800..0f2a42686854 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -55,6 +55,7 @@ import google.protobuf.wrappers_pb2 import proto +from google.protobuf import message from apache_beam.coders import coder_impl from apache_beam.coders.avro_record import AvroRecord @@ -65,7 +66,6 @@ from apache_beam.utils import proto_utils if TYPE_CHECKING: - from google.protobuf import message # pylint: disable=ungrouped-imports from apache_beam.coders.typecoders import CoderRegistry from apache_beam.runners.pipeline_context import PipelineContext @@ -1039,11 +1039,18 @@ def __hash__(self): @classmethod def from_type_hint(cls, typehint, unused_registry): - if issubclass(typehint, proto_utils.message_types): + # The typehint must be a strict subclass of google.protobuf.message.Message. + # ProtoCoder cannot work with message.Message itself, as deserialization of + # a serialized proto requires knowledge of the desired concrete proto + # subclass which is not stored in the encoded bytes themselves. If this + # occurs, an error is raised and the system defaults to other fallback + # coders. 
+ if (issubclass(typehint, proto_utils.message_types) and + typehint != message.Message): return cls(typehint) else: raise ValueError(( - 'Expected a subclass of google.protobuf.message.Message' + 'Expected a strict subclass of google.protobuf.message.Message' ', but got a %s' % typehint)) def to_type_hint(self): @@ -1350,12 +1357,48 @@ def __hash__(self): common_urns.coders.INTERVAL_WINDOW.urn, IntervalWindowCoder) +class _OrderedUnionCoder(FastCoder): + def __init__( + self, *coder_types: Tuple[type, Coder], fallback_coder: Optional[Coder]): + self._coder_types = coder_types + self._fallback_coder = fallback_coder + + def _create_impl(self): + return coder_impl._OrderedUnionCoderImpl( + [(t, c.get_impl()) for t, c in self._coder_types], + fallback_coder_impl=self._fallback_coder.get_impl() + if self._fallback_coder else None) + + def is_deterministic(self) -> bool: + return ( + all(c.is_deterministic for _, c in self._coder_types) and ( + self._fallback_coder is None or + self._fallback_coder.is_deterministic())) + + def to_type_hint(self): + return Any + + def __eq__(self, other): + return ( + type(self) == type(other) and + self._coder_types == other._coder_types and + self._fallback_coder == other._fallback_coder) + + def __hash__(self): + return hash((type(self), tuple(self._coder_types), self._fallback_coder)) + + class WindowedValueCoder(FastCoder): """Coder for windowed values.""" def __init__(self, wrapped_value_coder, window_coder=None): # type: (Coder, Optional[Coder]) -> None if not window_coder: - window_coder = PickleCoder() + # Avoid circular imports. 
+ from apache_beam.transforms import window + window_coder = _OrderedUnionCoder( + (window.GlobalWindow, GlobalWindowCoder()), + (window.IntervalWindow, IntervalWindowCoder()), + fallback_coder=PickleCoder()) self.wrapped_value_coder = wrapped_value_coder self.timestamp_coder = TimestampCoder() self.window_coder = window_coder diff --git a/sdks/python/apache_beam/coders/coders_test.py b/sdks/python/apache_beam/coders/coders_test.py index dc9780e36be3..5e5debca36e6 100644 --- a/sdks/python/apache_beam/coders/coders_test.py +++ b/sdks/python/apache_beam/coders/coders_test.py @@ -22,6 +22,7 @@ import proto import pytest +from google.protobuf import message import apache_beam as beam from apache_beam import typehints @@ -86,6 +87,23 @@ def test_proto_coder(self): self.assertEqual(ma, real_coder.decode(real_coder.encode(ma))) self.assertEqual(ma.__class__, real_coder.to_type_hint()) + def test_proto_coder_on_protobuf_message_subclasses(self): + # This replicates a scenario where users provide message.Message as the + # output typehint for a Map function, even though the actual output messages + # are subclasses of message.Message. + ma = test_message.MessageA() + mb = ma.field2.add() + mb.field1 = True + ma.field1 = 'hello world' + + coder = coders_registry.get_coder(message.Message) + # For messages of google.protobuf.message.Message, the fallback coder will + # be FastPrimitivesCoder rather than ProtoCoder. + # See the comment on ProtoCoder.from_type_hint() for further details. 
+ self.assertEqual(coder, coders.FastPrimitivesCoder()) + + self.assertEqual(ma, coder.decode(coder.encode(ma))) + class DeterministicProtoCoderTest(unittest.TestCase): def test_deterministic_proto_coder(self): diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 4bd9698dd57b..f3381cdb1d69 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -769,6 +769,14 @@ def test_decimal_coder(self): test_encodings[idx], base64.b64encode(test_coder.encode(value)).decode().rstrip("=")) + def test_OrderedUnionCoder(self): + test_coder = coders._OrderedUnionCoder((str, coders.StrUtf8Coder()), + (int, coders.VarIntCoder()), + fallback_coder=coders.FloatCoder()) + self.check_coder(test_coder, 's') + self.check_coder(test_coder, 123) + self.check_coder(test_coder, 1.5) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/examples/wordcount.py b/sdks/python/apache_beam/examples/wordcount.py index 31407aec6c40..a9138647581c 100644 --- a/sdks/python/apache_beam/examples/wordcount.py +++ b/sdks/python/apache_beam/examples/wordcount.py @@ -45,6 +45,7 @@ from apache_beam.io import WriteToText from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult class WordExtractingDoFn(beam.DoFn): @@ -63,7 +64,7 @@ def process(self, element): return re.findall(r'[\w\']+', element, re.UNICODE) -def run(argv=None, save_main_session=True): +def run(argv=None, save_main_session=True) -> PipelineResult: """Main entry point; defines and runs the wordcount pipeline.""" parser = argparse.ArgumentParser() parser.add_argument( @@ -83,27 +84,31 @@ def run(argv=None, save_main_session=True): pipeline_options = PipelineOptions(pipeline_args) 
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session - # The pipeline will be run on exiting the with block. - with beam.Pipeline(options=pipeline_options) as p: + pipeline = beam.Pipeline(options=pipeline_options) - # Read the text file[pattern] into a PCollection. - lines = p | 'Read' >> ReadFromText(known_args.input) + # Read the text file[pattern] into a PCollection. + lines = pipeline | 'Read' >> ReadFromText(known_args.input) - counts = ( - lines - | 'Split' >> (beam.ParDo(WordExtractingDoFn()).with_output_types(str)) - | 'PairWithOne' >> beam.Map(lambda x: (x, 1)) - | 'GroupAndSum' >> beam.CombinePerKey(sum)) + counts = ( + lines + | 'Split' >> (beam.ParDo(WordExtractingDoFn()).with_output_types(str)) + | 'PairWithOne' >> beam.Map(lambda x: (x, 1)) + | 'GroupAndSum' >> beam.CombinePerKey(sum)) - # Format the counts into a PCollection of strings. - def format_result(word, count): - return '%s: %d' % (word, count) + # Format the counts into a PCollection of strings. + def format_result(word, count): + return '%s: %d' % (word, count) - output = counts | 'Format' >> beam.MapTuple(format_result) + output = counts | 'Format' >> beam.MapTuple(format_result) - # Write the output using a "Write" transform that has side effects. - # pylint: disable=expression-not-assigned - output | 'Write' >> WriteToText(known_args.output) + # Write the output using a "Write" transform that has side effects. + # pylint: disable=expression-not-assigned + output | 'Write' >> WriteToText(known_args.output) + + # Execute the pipeline and return the result. 
+ result = pipeline.run() + result.wait_until_finish() + return result if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/aws/s3filesystem.py b/sdks/python/apache_beam/io/aws/s3filesystem.py index ffbce5893a96..584263ec241e 100644 --- a/sdks/python/apache_beam/io/aws/s3filesystem.py +++ b/sdks/python/apache_beam/io/aws/s3filesystem.py @@ -18,6 +18,7 @@ """S3 file system implementation for accessing files on AWS S3.""" # pytype: skip-file +import traceback from apache_beam.io.aws import s3io from apache_beam.io.filesystem import BeamIOError @@ -315,14 +316,13 @@ def delete(self, paths): if exceptions: raise BeamIOError("Delete operation failed", exceptions) - def report_lineage(self, path, lineage, level=None): + def report_lineage(self, path, lineage): try: components = s3io.parse_s3_path(path, object_optional=True) except ValueError: # report lineage is fail-safe + traceback.print_exc() return - if level == FileSystem.LineageLevel.TOP_LEVEL or \ - (len(components) > 1 and components[-1] == ''): - # bucket only + if components and not components[-1]: components = components[:-1] - lineage.add('s3', *components) + lineage.add('s3', *components, last_segment_sep='/') diff --git a/sdks/python/apache_beam/io/aws/s3filesystem_test.py b/sdks/python/apache_beam/io/aws/s3filesystem_test.py index 87403f482bd2..036727cd7a70 100644 --- a/sdks/python/apache_beam/io/aws/s3filesystem_test.py +++ b/sdks/python/apache_beam/io/aws/s3filesystem_test.py @@ -272,7 +272,8 @@ def test_lineage(self): def _verify_lineage(self, uri, expected_segments): lineage_mock = mock.MagicMock() self.fs.report_lineage(uri, lineage_mock) - lineage_mock.add.assert_called_once_with("s3", *expected_segments) + lineage_mock.add.assert_called_once_with( + "s3", *expected_segments, last_segment_sep='/') if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py index 4495245dc54a..4b7462cae03c 
100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py @@ -18,6 +18,7 @@ """Azure Blob Storage Implementation for accesing files on Azure Blob Storage. """ +import traceback from apache_beam.io.azure import blobstorageio from apache_beam.io.filesystem import BeamIOError @@ -317,15 +318,14 @@ def delete(self, paths): if exceptions: raise BeamIOError("Delete operation failed", exceptions) - def report_lineage(self, path, lineage, level=None): + def report_lineage(self, path, lineage): try: components = blobstorageio.parse_azfs_path( path, blob_optional=True, get_account=True) except ValueError: # report lineage is fail-safe + traceback.print_exc() return - if level == FileSystem.LineageLevel.TOP_LEVEL \ - or(len(components) > 1 and components[-1] == ''): - # bucket only + if components and not components[-1]: components = components[:-1] - lineage.add('abs', *components) + lineage.add('abs', *components, last_segment_sep='/') diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py index 138fe5f78b20..c3418e137e87 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py @@ -330,7 +330,8 @@ def test_lineage(self): def _verify_lineage(self, uri, expected_segments): lineage_mock = mock.MagicMock() self.fs.report_lineage(uri, lineage_mock) - lineage_mock.add.assert_called_once_with("abs", *expected_segments) + lineage_mock.add.assert_called_once_with( + "abs", *expected_segments, last_segment_sep='/') if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py b/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py index 151d63d84684..c9181fb2a721 100644 --- a/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py +++ 
b/sdks/python/apache_beam/io/external/xlang_kinesisio_it_test.py @@ -64,7 +64,7 @@ DockerContainer = None # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports -LOCALSTACK_VERSION = '0.11.3' +LOCALSTACK_VERSION = '3.8.1' NUM_RECORDS = 10 MAX_READ_TIME = 5 * 60 * 1000 # 5min NOW_SECONDS = time.time() @@ -116,9 +116,7 @@ def run_kinesis_write(self): region=self.aws_region, service_endpoint=self.aws_service_endpoint, verify_certificate=(not self.use_localstack), - partition_key='1', - producer_properties=self.producer_properties, - )) + partition_key='1')) def run_kinesis_read(self): records = [RECORD + str(i).encode() for i in range(NUM_RECORDS)] @@ -145,12 +143,11 @@ def run_kinesis_read(self): def set_localstack(self): self.localstack = DockerContainer('localstack/localstack:{}' - .format(LOCALSTACK_VERSION))\ - .with_env('SERVICES', 'kinesis')\ - .with_env('KINESIS_PORT', '4568')\ - .with_env('USE_SSL', 'true')\ - .with_exposed_ports(4568)\ - .with_volume_mapping('/var/run/docker.sock', '/var/run/docker.sock', 'rw') + .format(LOCALSTACK_VERSION))\ + .with_bind_ports(4566, 4566) + + for i in range(4510, 4560): + self.localstack = self.localstack.with_bind_ports(i, i) # Repeat if ReadTimeout is raised. 
for i in range(4): @@ -164,7 +161,7 @@ def set_localstack(self): self.aws_service_endpoint = 'https://{}:{}'.format( self.localstack.get_container_host_ip(), - self.localstack.get_exposed_port('4568'), + self.localstack.get_exposed_port('4566'), ) def setUp(self): @@ -219,10 +216,6 @@ def setUp(self): self.aws_service_endpoint = known_args.aws_service_endpoint self.use_localstack = not known_args.use_real_aws self.expansion_service = known_args.expansion_service - self.producer_properties = { - 'CollectionMaxCount': str(NUM_RECORDS), - 'ConnectTimeout': str(MAX_READ_TIME), - } if self.use_localstack: self.set_localstack() diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py index f9d4303c8c78..eb433bd60583 100644 --- a/sdks/python/apache_beam/io/filebasedsink.py +++ b/sdks/python/apache_beam/io/filebasedsink.py @@ -286,24 +286,16 @@ def _check_state_for_finalize_write(self, writer_results, num_shards): def _report_sink_lineage(self, dst_glob, dst_files): """ - Report sink Lineage. Report every file if number of files no more than 100, - otherwise only report at directory level. + Report sink Lineage. Report every file if number of files no more than 10, + otherwise only report glob. """ - if len(dst_files) <= 100: + # There is rollup at the higher level, but this loses glob information. + # Better to report multiple globs than just the parent directory. 
+ if len(dst_files) <= 10: for dst in dst_files: FileSystems.report_sink_lineage(dst) else: - dst = dst_glob - # dst_glob has a wildcard for shard number (see _shard_name_template) - sep = dst_glob.find('*') - if sep > 0: - dst = dst[:sep] - try: - dst, _ = FileSystems.split(dst) - except ValueError: - return # lineage report is fail-safe - - FileSystems.report_sink_lineage(dst) + FileSystems.report_sink_lineage(dst_glob) @check_accessible(['file_path_prefix']) def finalize_write( diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index a02bc6de32c7..49b1b1d125f1 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -39,7 +39,6 @@ from apache_beam.io import range_trackers from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.filesystem import FileMetadata -from apache_beam.io.filesystem import FileSystem from apache_beam.io.filesystems import FileSystems from apache_beam.io.restriction_trackers import OffsetRange from apache_beam.options.value_provider import StaticValueProvider @@ -170,37 +169,11 @@ def _get_concat_source(self) -> concat_source.ConcatSource: splittable=splittable) single_file_sources.append(single_file_source) - self._report_source_lineage(files_metadata) + FileSystems.report_source_lineage(pattern) self._concat_source = concat_source.ConcatSource(single_file_sources) return self._concat_source - def _report_source_lineage(self, files_metadata): - """ - Report source Lineage. 
depend on the number of files, report full file - name, only dir, or only top level - """ - if len(files_metadata) <= 100: - for file_metadata in files_metadata: - FileSystems.report_source_lineage(file_metadata.path) - else: - size_track = set() - for file_metadata in files_metadata: - if len(size_track) >= 100: - FileSystems.report_source_lineage( - file_metadata.path, level=FileSystem.LineageLevel.TOP_LEVEL) - return - - try: - base, _ = FileSystems.split(file_metadata.path) - except ValueError: - pass - else: - size_track.add(base) - - for base in size_track: - FileSystems.report_source_lineage(base) - def open_file(self, file_name): return FileSystems.open( file_name, @@ -382,7 +355,7 @@ def process(self, element: Union[str, FileMetadata], *args, match_results = FileSystems.match([element]) metadata_list = match_results[0].metadata_list for metadata in metadata_list: - self._report_source_lineage(metadata.path) + FileSystems.report_source_lineage(metadata.path) splittable = ( self._splittable and _determine_splittability_from_compression_type( @@ -397,28 +370,6 @@ def process(self, element: Union[str, FileMetadata], *args, metadata, OffsetRange(0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY)) - def _report_source_lineage(self, path): - """ - Report source Lineage. Due to the size limit of Beam metrics, report full - file name or only top level depend on the number of files. - - * Number of files<=100, report full file paths; - - * Otherwise, report top level only. 
- """ - if self._size_track is None: - self._size_track = set() - elif len(self._size_track) == 0: - FileSystems.report_source_lineage( - path, level=FileSystem.LineageLevel.TOP_LEVEL) - return - - self._size_track.add(path) - FileSystems.report_source_lineage(path) - - if len(self._size_track) >= 100: - self._size_track.clear() - class _ReadRange(DoFn): def __init__( diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index 840fdf3309e7..bdc25dcf0fe5 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -934,11 +934,7 @@ def delete(self, paths): """ raise NotImplementedError - class LineageLevel: - FILE = 'FILE' - TOP_LEVEL = 'TOP_LEVEL' - - def report_lineage(self, path, unused_lineage, level=None): + def report_lineage(self, path, unused_lineage): """ Report Lineage metrics for path. diff --git a/sdks/python/apache_beam/io/filesystems.py b/sdks/python/apache_beam/io/filesystems.py index 87f45f3308ee..1d64f88684b8 100644 --- a/sdks/python/apache_beam/io/filesystems.py +++ b/sdks/python/apache_beam/io/filesystems.py @@ -391,27 +391,21 @@ def get_chunk_size(path): return filesystem.CHUNK_SIZE @staticmethod - def report_source_lineage(path, level=None): + def report_source_lineage(path): """ - Report source :class:`~apache_beam.metrics.metric.LineageLevel`. + Report source :class:`~apache_beam.metrics.metric.Lineage`. Args: path: string path to be reported. - level: the level of file path. default to - :class:`~apache_beam.io.filesystem.FileSystem.LineageLevel`.FILE. """ - filesystem = FileSystems.get_filesystem(path) - filesystem.report_lineage(path, Lineage.sources(), level=level) + FileSystems.get_filesystem(path).report_lineage(path, Lineage.sources()) @staticmethod - def report_sink_lineage(path, level=None): + def report_sink_lineage(path): """ Report sink :class:`~apache_beam.metrics.metric.Lineage`. Args: path: string path to be reported. 
- level: the level of file path. default to - :class:`~apache_beam.io.filesystem.FileSystem.Lineage`.FILE. """ - filesystem = FileSystems.get_filesystem(path) - filesystem.report_lineage(path, Lineage.sinks(), level=level) + FileSystems.get_filesystem(path).report_lineage(path, Lineage.sinks()) diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 11e0d098b2f3..9f60b5af6726 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -1163,7 +1163,7 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): self.table_reference.datasetId, self.table_reference.tableId) Lineage.sources().add( - "bigquery", + 'bigquery', self.table_reference.projectId, self.table_reference.datasetId, self.table_reference.tableId) diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py index 325f70ddfd96..96aca2c410d8 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py @@ -26,6 +26,7 @@ # pytype: skip-file +import traceback from typing import BinaryIO # pylint: disable=unused-import from apache_beam.io.filesystem import BeamIOError @@ -366,14 +367,13 @@ def delete(self, paths): if exceptions: raise BeamIOError("Delete operation failed", exceptions) - def report_lineage(self, path, lineage, level=None): + def report_lineage(self, path, lineage): try: components = gcsio.parse_gcs_path(path, object_optional=True) except ValueError: # report lineage is fail-safe + traceback.print_exc() return - if level == FileSystem.LineageLevel.TOP_LEVEL \ - or(len(components) > 1 and components[-1] == ''): - # bucket only + if components and not components[-1]: components = components[:-1] - lineage.add('gcs', *components) + lineage.add('gcs', *components, last_segment_sep='/') diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py 
b/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py index ec7fa94b05fd..ade8529dcac8 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py @@ -382,7 +382,8 @@ def test_lineage(self): def _verify_lineage(self, uri, expected_segments): lineage_mock = mock.MagicMock() self.fs.report_lineage(uri, lineage_mock) - lineage_mock.add.assert_called_once_with("gcs", *expected_segments) + lineage_mock.add.assert_called_once_with( + "gcs", *expected_segments, last_segment_sep='/') if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/kinesis.py b/sdks/python/apache_beam/io/kinesis.py index bc5e1fa787b4..ce0bb2623a38 100644 --- a/sdks/python/apache_beam/io/kinesis.py +++ b/sdks/python/apache_beam/io/kinesis.py @@ -49,7 +49,8 @@ In this option, Python SDK will either download (for released Beam version) or build (when running from a Beam Git clone) a expansion service jar and use that to expand transforms. Currently Kinesis transforms use the - 'beam-sdks-java-io-kinesis-expansion-service' jar for this purpose. + 'beam-sdks-java-io-amazon-web-services2-expansion-service' jar for this + purpose. *Option 2: specify a custom expansion service* @@ -81,7 +82,6 @@ import logging import time -from typing import Mapping from typing import NamedTuple from typing import Optional @@ -99,7 +99,7 @@ def default_io_expansion_service(): return BeamJarExpansionService( - 'sdks:java:io:kinesis:expansion-service:shadowJar') + 'sdks:java:io:amazon-web-services2:expansion-service:shadowJar') WriteToKinesisSchema = NamedTuple( @@ -112,7 +112,6 @@ def default_io_expansion_service(): ('partition_key', str), ('service_endpoint', Optional[str]), ('verify_certificate', Optional[bool]), - ('producer_properties', Optional[Mapping[str, str]]), ], ) @@ -123,7 +122,7 @@ class WriteToKinesis(ExternalTransform): Experimental; no backwards compatibility guarantees. 
""" - URN = 'beam:transform:org.apache.beam:kinesis_write:v1' + URN = 'beam:transform:org.apache.beam:kinesis_write:v2' def __init__( self, @@ -148,11 +147,15 @@ def __init__( :param verify_certificate: Enable or disable certificate verification. Never set to False on production. True by default. :param partition_key: Specify default partition key. - :param producer_properties: Specify the configuration properties for Kinesis - Producer Library (KPL) as dictionary. - Example: {'CollectionMaxCount': '1000', 'ConnectTimeout': '10000'} + :param producer_properties: (Deprecated) This option no longer is available + since the AWS IOs upgraded to v2. Trying to set it will lead to an + error. For more info, see https://github.com/apache/beam/issues/33430. :param expansion_service: The address (host:port) of the ExpansionService. """ + if producer_properties is not None: + raise ValueError( + 'producer_properties is no longer supported and will be removed ' + + 'in a future release.') super().__init__( self.URN, NamedTupleBasedPayloadBuilder( @@ -164,7 +167,6 @@ def __init__( partition_key=partition_key, service_endpoint=service_endpoint, verify_certificate=verify_certificate, - producer_properties=producer_properties, )), expansion_service or default_io_expansion_service(), ) @@ -199,7 +201,7 @@ class ReadDataFromKinesis(ExternalTransform): Experimental; no backwards compatibility guarantees. 
""" - URN = 'beam:transform:org.apache.beam:kinesis_read_data:v1' + URN = 'beam:transform:org.apache.beam:kinesis_read_data:v2' def __init__( self, diff --git a/sdks/python/apache_beam/io/localfilesystem.py b/sdks/python/apache_beam/io/localfilesystem.py index e9fe7dd4b1c2..daf69b8d030c 100644 --- a/sdks/python/apache_beam/io/localfilesystem.py +++ b/sdks/python/apache_beam/io/localfilesystem.py @@ -364,3 +364,6 @@ def try_delete(path): if exceptions: raise BeamIOError("Delete operation failed", exceptions) + + def report_lineage(self, path, lineage): + lineage.add('filesystem', 'localhost', path, last_segment_sep='/') diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd index c583dabeb0c0..ebadeec97984 100644 --- a/sdks/python/apache_beam/metrics/cells.pxd +++ b/sdks/python/apache_beam/metrics/cells.pxd @@ -55,8 +55,23 @@ cdef class StringSetCell(AbstractMetricCell): pass +cdef class BoundedTrieCell(AbstractMetricCell): + pass + + cdef class DistributionData(object): cdef readonly libc.stdint.int64_t sum cdef readonly libc.stdint.int64_t count cdef readonly libc.stdint.int64_t min cdef readonly libc.stdint.int64_t max + + +cdef class _BoundedTrieNode(object): + cdef readonly libc.stdint.int64_t _size + cdef readonly dict _children + cdef readonly bint _truncated + +cdef class BoundedTrieData(object): + cdef readonly libc.stdint.int64_t _bound + cdef readonly object _singleton + cdef readonly _BoundedTrieNode _root diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 5802c6914eb2..c2c2e8015ef2 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -23,6 +23,7 @@ # pytype: skip-file +import copy import logging import threading import time @@ -31,6 +32,8 @@ from typing import Optional from typing import Set +from apache_beam.portability.api import metrics_pb2 + try: import cython except ImportError: @@ -312,6 +315,35 @@ def 
to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) +class BoundedTrieCell(AbstractMetricCell): + """For internal use only; no backwards-compatibility guarantees. + + Tracks the current value for a BoundedTrie metric. + + Each cell tracks the state of a metric independently per context per bundle. + Therefore, each metric has a different cell in each bundle, that is later + aggregated. + + This class is thread safe. + """ + def __init__(self): + super().__init__(BoundedTrieData) + + def add(self, value): + self.update(value) + + def _update_locked(self, value): + self.data.add(value) + + def to_runner_api_monitoring_info_impl(self, name, transform_id): + from apache_beam.metrics import monitoring_infos + return monitoring_infos.user_bounded_trie( + name.namespace, + name.name, + self.get_cumulative(), + ptransform=transform_id) + + class DistributionResult(object): """The result of a Distribution metric.""" def __init__(self, data): @@ -630,3 +662,244 @@ def singleton(value: str) -> "StringSetData": @staticmethod def identity_element() -> "StringSetData": return StringSetData() + + +class _BoundedTrieNode(object): + def __init__(self): + # invariant: size = len(self.flattened()) = min(1, sum(size of children)) + self._size = 1 + self._children: Optional[dict[str, '_BoundedTrieNode']] = {} + self._truncated = False + + def to_proto(self) -> metrics_pb2.BoundedTrieNode: + return metrics_pb2.BoundedTrieNode( + truncated=self._truncated, + children={ + name: child.to_proto() + for name, child in self._children.items() + } if self._children else None) + + @staticmethod + def from_proto(proto: metrics_pb2.BoundedTrieNode) -> '_BoundedTrieNode': + node = _BoundedTrieNode() + if proto.truncated: + node._truncated = True + node._children = None + else: + node._children = { + name: _BoundedTrieNode.from_proto(child) + for name, + child in proto.children.items() + } + node._size = max(1, sum(child._size for child in node._children.values())) 
+ return node + + def size(self): + return self._size + + def contains(self, segments): + if self._truncated or not segments: + return True + head, *tail = segments + return head in self._children and self._children[head].contains(tail) + + def add(self, segments) -> int: + if self._truncated or not segments: + return 0 + head, *tail = segments + was_empty = not self._children + child = self._children.get(head, None) # type: ignore[union-attr] + if child is None: + child = self._children[head] = _BoundedTrieNode() # type: ignore[index] + delta = 0 if was_empty else 1 + else: + delta = 0 + if tail: + delta += child.add(tail) + self._size += delta + return delta + + def add_all(self, segments_iter): + return sum(self.add(segments) for segments in segments_iter) + + def trim(self) -> int: + if not self._children: + return 0 + max_child = max(self._children.values(), key=lambda child: child._size) + if max_child._size == 1: + delta = 1 - self._size + self._truncated = True + self._children = None + else: + delta = max_child.trim() + self._size += delta + return delta + + def merge(self, other: '_BoundedTrieNode') -> int: + if self._truncated: + delta = 0 + elif other._truncated: + delta = 1 - self._size + self._truncated = True + self._children = None + elif not other._children: + delta = 0 + elif not self._children: + self._children = other._children + delta = other._size - self._size + else: + delta = 0 + other_child: '_BoundedTrieNode' + self_child: Optional['_BoundedTrieNode'] + for prefix, other_child in other._children.items(): + self_child = self._children.get(prefix, None) + if self_child is None: + self._children[prefix] = other_child + delta += other_child._size + else: + delta += self_child.merge(other_child) + self._size += delta + return delta + + def flattened(self): + if self._truncated: + yield (True, ) + elif not self._children: + yield (False, ) + else: + for prefix, child in sorted(self._children.items()): + for flattened in child.flattened(): + 
yield (prefix, ) + flattened + + def __hash__(self): + return self._truncated or hash(sorted(self._children.items())) + + def __eq__(self, other): + if isinstance(other, _BoundedTrieNode): + return ( + self._truncated == other._truncated and + self._children == other._children) + else: + return False + + def __repr__(self): + return repr(set(''.join(str(s) for s in t) for t in self.flattened())) + + +class BoundedTrieData(object): + _DEFAULT_BOUND = 100 + + def __init__(self, *, root=None, singleton=None, bound=_DEFAULT_BOUND): + self._singleton = singleton + self._root = root + self._bound = bound + assert singleton is None or root is None + + def size(self): + if self._singleton is not None: + return 1 + elif self._root is not None: + return self._root.size() + else: + return 0 + + def contains(self, value): + if self._singleton is not None: + return tuple(value) == self._singleton + elif self._root is not None: + return self._root.contains(value) + else: + return False + + def flattened(self): + return self.as_trie().flattened() + + def to_proto(self) -> metrics_pb2.BoundedTrie: + return metrics_pb2.BoundedTrie( + bound=self._bound, + singleton=self._singleton if self._singleton else None, + root=self._root.to_proto() if self._root else None) + + @staticmethod + def from_proto(proto: metrics_pb2.BoundedTrie) -> 'BoundedTrieData': + return BoundedTrieData( + bound=proto.bound, + singleton=tuple(proto.singleton) if proto.singleton else None, + root=( + _BoundedTrieNode.from_proto(proto.root) + if proto.HasField('root') else None)) + + def as_trie(self): + if self._root is not None: + return self._root + else: + root = _BoundedTrieNode() + if self._singleton is not None: + root.add(self._singleton) + return root + + def __eq__(self, other: object) -> bool: + if isinstance(other, BoundedTrieData): + return self.as_trie() == other.as_trie() + else: + return False + + def __hash__(self) -> int: + return hash(self.as_trie()) + + def __repr__(self) -> str: + return 
'BoundedTrieData({})'.format(self.as_trie()) + + def get_cumulative(self) -> "BoundedTrieData": + return copy.deepcopy(self) + + def get_result(self) -> Set[tuple]: + if self._root is None: + if self._singleton is None: + return set() + else: + return set([self._singleton + (False, )]) + else: + return set(self._root.flattened()) + + def add(self, segments): + if self._root is None and self._singleton is None: + self._singleton = segments + elif self._singleton is not None and self._singleton == segments: + # Optimize for the common case of re-adding the same value. + return + else: + if self._root is None: + self._root = self.as_trie() + self._singleton = None + self._root.add(segments) + if self._root._size > self._bound: + self._root.trim() + + def combine(self, other: "BoundedTrieData") -> "BoundedTrieData": + if self._root is None and self._singleton is None: + return other + elif other._root is None and other._singleton is None: + return self + else: + if self._root is None and other._root is not None: + self, other = other, self + combined = copy.deepcopy(self.as_trie()) + if other._root is not None: + combined.merge(other._root) + else: + combined.add(other._singleton) + self._bound = min(self._bound, other._bound) + while combined._size > self._bound: + combined.trim() + return BoundedTrieData(root=combined) + + @staticmethod + def singleton(value: str) -> "BoundedTrieData": + s = BoundedTrieData() + s.add(value) + return s + + @staticmethod + def identity_element() -> "BoundedTrieData": + return BoundedTrieData() diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index d1ee37b8ed82..1cd15fced86c 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -17,9 +17,13 @@ # pytype: skip-file +import copy +import itertools +import random import threading import unittest +from apache_beam.metrics.cells import BoundedTrieData from 
apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import DistributionData @@ -27,6 +31,7 @@ from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import StringSetCell from apache_beam.metrics.cells import StringSetData +from apache_beam.metrics.cells import _BoundedTrieNode from apache_beam.metrics.metricbase import MetricName @@ -203,5 +208,235 @@ def test_add_size_tracked_correctly(self): self.assertEqual(s.data.string_size, 3) +class TestBoundedTrieNode(unittest.TestCase): + @classmethod + def random_segments_fixed_depth(cls, n, depth, overlap, rand): + if depth == 0: + yield from ((), ) * n + else: + seen = [] + to_string = lambda ix: chr(ord('a') + ix) if ix < 26 else f'z{ix}' + for suffix in cls.random_segments_fixed_depth(n, depth - 1, overlap, + rand): + if not seen or rand.random() > overlap: + prefix = to_string(len(seen)) + seen.append(prefix) + else: + prefix = rand.choice(seen) + yield (prefix, ) + suffix + + @classmethod + def random_segments(cls, n, min_depth, max_depth, overlap, rand): + for depth, segments in zip( + itertools.cycle(range(min_depth, max_depth + 1)), + cls.random_segments_fixed_depth(n, max_depth, overlap, rand)): + yield segments[:depth] + + def assert_covers(self, node, expected, max_truncated=0): + self.assert_covers_flattened(node.flattened(), expected, max_truncated) + + def assert_covers_flattened(self, flattened, expected, max_truncated=0): + expected = set(expected) + # Split node into the exact and truncated segments. + partitioned = {True: set(), False: set()} + for segments in flattened: + partitioned[segments[-1]].add(segments[:-1]) + exact, truncated = partitioned[False], partitioned[True] + # Check we cover both parts. 
+ self.assertLessEqual(len(truncated), max_truncated, truncated) + self.assertTrue(exact.issubset(expected), exact - expected) + seen_truncated = set() + for segments in expected - exact: + found = 0 + for ix in range(len(segments)): + if segments[:ix] in truncated: + seen_truncated.add(segments[:ix]) + found += 1 + if found != 1: + self.fail( + f"Expected exactly one prefix of {segments} " + f"to occur in {truncated}, found {found}") + self.assertEqual(seen_truncated, truncated, truncated - seen_truncated) + + def run_covers_test(self, flattened, expected, max_truncated): + def parse(s): + return tuple(s.strip('*')) + (s.endswith('*'), ) + + self.assert_covers_flattened([parse(s) for s in flattened], + [tuple(s) for s in expected], + max_truncated) + + def test_covers_exact(self): + self.run_covers_test(['ab', 'ac', 'cd'], ['ab', 'ac', 'cd'], 0) + with self.assertRaises(AssertionError): + self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 0) + with self.assertRaises(AssertionError): + self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 0) + with self.assertRaises(AssertionError): + self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 0) + + def test_covers_trunacted(self): + self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 1) + self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'abcde', 'cd'], 1) + with self.assertRaises(AssertionError): + self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 1) + with self.assertRaises(AssertionError): + self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 1) + with self.assertRaises(AssertionError): + self.run_covers_test(['a*', 'c*'], ['ab', 'ac', 'cd'], 1) + with self.assertRaises(AssertionError): + self.run_covers_test(['a*', 'c*'], ['ab', 'ac'], 1) + + def run_test(self, to_add): + everything = list(set(to_add)) + all_prefixees = set( + segments[:ix] for segments in everything for ix in range(len(segments))) + everything_deduped = set(everything) - all_prefixees + + # Check basic addition. 
+ node = _BoundedTrieNode() + total_size = node.size() + self.assertEqual(total_size, 1) + for segments in everything: + total_size += node.add(segments) + self.assertEqual(node.size(), len(everything_deduped), node) + self.assertEqual(node.size(), total_size, node) + self.assert_covers(node, everything_deduped) + + # Check merging + node0 = _BoundedTrieNode() + node0.add_all(everything[0::2]) + node1 = _BoundedTrieNode() + node1.add_all(everything[1::2]) + pre_merge_size = node0.size() + merge_delta = node0.merge(node1) + self.assertEqual(node0.size(), pre_merge_size + merge_delta) + self.assertEqual(node0, node) + + # Check trimming. + if node.size() > 1: + trim_delta = node.trim() + self.assertLess(trim_delta, 0, node) + self.assertEqual(node.size(), total_size + trim_delta) + self.assert_covers(node, everything_deduped, max_truncated=1) + + if node.size() > 1: + trim2_delta = node.trim() + self.assertLess(trim2_delta, 0) + self.assertEqual(node.size(), total_size + trim_delta + trim2_delta) + self.assert_covers(node, everything_deduped, max_truncated=2) + + # Adding after trimming should be a no-op. + node_copy = copy.deepcopy(node) + for segments in everything: + self.assertEqual(node.add(segments), 0) + self.assertEqual(node, node_copy) + + # Merging after trimming should be a no-op. + self.assertEqual(node.merge(node0), 0) + self.assertEqual(node.merge(node1), 0) + self.assertEqual(node, node_copy) + + if node._truncated: + expected_delta = 0 + else: + expected_delta = 2 + + # Adding something new is not. + new_values = [('new1', ), ('new2', 'new2.1')] + self.assertEqual(node.add_all(new_values), expected_delta) + self.assert_covers( + node, list(everything_deduped) + new_values, max_truncated=2) + + # Nor is merging something new. 
+ new_values_node = _BoundedTrieNode() + new_values_node.add_all(new_values) + self.assertEqual(node_copy.merge(new_values_node), expected_delta) + self.assert_covers( + node_copy, list(everything_deduped) + new_values, max_truncated=2) + + def run_fuzz(self, iterations=10, **params): + for _ in range(iterations): + seed = random.getrandbits(64) + segments = self.random_segments(**params, rand=random.Random(seed)) + try: + self.run_test(segments) + except: + print("SEED", seed) + raise + + def test_trivial(self): + self.run_test([('a', 'b'), ('a', 'c')]) + + def test_flat(self): + self.run_test([('a', 'a'), ('b', 'b'), ('c', 'c')]) + + def test_deep(self): + self.run_test([('a', ) * 10, ('b', ) * 12]) + + def test_small(self): + self.run_fuzz(n=5, min_depth=2, max_depth=3, overlap=0.5) + + def test_medium(self): + self.run_fuzz(n=20, min_depth=2, max_depth=4, overlap=0.5) + + def test_large_sparse(self): + self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.2) + + def test_large_dense(self): + self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.8) + + def test_bounded_trie_data_combine(self): + empty = BoundedTrieData() + # The merging here isn't complicated we're just ensuring that + # BoundedTrieData invokes _BoundedTrieNode correctly. 
+ singletonA = BoundedTrieData(singleton=('a', 'a')) + singletonB = BoundedTrieData(singleton=('b', 'b')) + lots_root = _BoundedTrieNode() + lots_root.add_all([('c', 'c'), ('d', 'd')]) + lots = BoundedTrieData(root=lots_root) + self.assertEqual(empty.get_result(), set()) + self.assertEqual( + empty.combine(singletonA).get_result(), set([('a', 'a', False)])) + self.assertEqual( + singletonA.combine(empty).get_result(), set([('a', 'a', False)])) + self.assertEqual( + singletonA.combine(singletonB).get_result(), + set([('a', 'a', False), ('b', 'b', False)])) + self.assertEqual( + singletonA.combine(lots).get_result(), + set([('a', 'a', False), ('c', 'c', False), ('d', 'd', False)])) + self.assertEqual( + lots.combine(singletonA).get_result(), + set([('a', 'a', False), ('c', 'c', False), ('d', 'd', False)])) + + def test_bounded_trie_data_combine_trim(self): + left = _BoundedTrieNode() + left.add_all([('a', 'x'), ('b', 'd')]) + right = _BoundedTrieNode() + right.add_all([('a', 'y'), ('c', 'd')]) + self.assertEqual( + BoundedTrieData(root=left).combine( + BoundedTrieData(root=right, bound=3)).get_result(), + set([('a', True), ('b', 'd', False), ('c', 'd', False)])) + + def test_merge_on_empty_node(self): + root1 = _BoundedTrieNode() + root2 = _BoundedTrieNode() + root2.add_all([["a", "b", "c"], ["a", "b", "d"], ["a", "e"]]) + self.assertEqual(2, root1.merge(root2)) + self.assertEqual(3, root1.size()) + self.assertFalse(root1._truncated) + + def test_merge_with_empty_node(self): + root1 = _BoundedTrieNode() + root1.add_all([["a", "b", "c"], ["a", "b", "d"], ["a", "e"]]) + root2 = _BoundedTrieNode() + + self.assertEqual(0, root1.merge(root2)) + self.assertEqual(3, root1.size()) + self.assertFalse(root1._truncated) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index fa70d3a4d9c0..c28c8340a505 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ 
b/sdks/python/apache_beam/metrics/execution.py @@ -43,6 +43,7 @@ from typing import cast from apache_beam.metrics import monitoring_infos +from apache_beam.metrics.cells import BoundedTrieCell from apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import GaugeCell @@ -52,6 +53,7 @@ from apache_beam.runners.worker.statesampler import get_current_tracker if TYPE_CHECKING: + from apache_beam.metrics.cells import BoundedTrieData from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import MetricCell @@ -265,6 +267,9 @@ def get_string_set(self, metric_name): StringSetCell, self.get_metric_cell(_TypedMetricName(StringSetCell, metric_name))) + def get_bounded_trie(self, metric_name): + return self.get_metric_cell(_TypedMetricName(BoundedTrieCell, metric_name)) + def get_metric_cell(self, typed_metric_name): # type: (_TypedMetricName) -> MetricCell cell = self.metrics.get(typed_metric_name, None) @@ -304,7 +309,14 @@ def get_cumulative(self): v in self.metrics.items() if k.cell_type == StringSetCell } - return MetricUpdates(counters, distributions, gauges, string_sets) + bounded_tries = { + MetricKey(self.step_name, k.metric_name): v.get_cumulative() + for k, + v in self.metrics.items() if k.cell_type == BoundedTrieCell + } + + return MetricUpdates( + counters, distributions, gauges, string_sets, bounded_tries) def to_runner_api(self): return [ @@ -358,6 +370,7 @@ def __init__( distributions=None, # type: Optional[Dict[MetricKey, DistributionData]] gauges=None, # type: Optional[Dict[MetricKey, GaugeData]] string_sets=None, # type: Optional[Dict[MetricKey, StringSetData]] + bounded_tries=None, # type: Optional[Dict[MetricKey, BoundedTrieData]] ): # type: (...) -> None @@ -368,8 +381,10 @@ def __init__( distributions: Dictionary of MetricKey:MetricUpdate objects. gauges: Dictionary of MetricKey:MetricUpdate objects. 
string_sets: Dictionary of MetricKey:MetricUpdate objects. + bounded_tries: Dictionary of MetricKey:MetricUpdate objects. """ self.counters = counters or {} self.distributions = distributions or {} self.gauges = gauges or {} self.string_sets = string_sets or {} + self.bounded_tries = bounded_tries or {} diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index 3e665dd805ea..9cf42370f4b1 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -33,6 +33,7 @@ from typing import Dict from typing import FrozenSet from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Set @@ -42,6 +43,7 @@ from apache_beam.metrics import cells from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.execution import MetricUpdater +from apache_beam.metrics.metricbase import BoundedTrie from apache_beam.metrics.metricbase import Counter from apache_beam.metrics.metricbase import Distribution from apache_beam.metrics.metricbase import Gauge @@ -135,6 +137,22 @@ def string_set( namespace = Metrics.get_namespace(namespace) return Metrics.DelegatingStringSet(MetricName(namespace, name)) + @staticmethod + def bounded_trie( + namespace: Union[Type, str], + name: str) -> 'Metrics.DelegatingBoundedTrie': + """Obtains or creates a Bounded Trie metric. + + Args: + namespace: A class or string that gives the namespace to a metric + name: A string that gives a unique name to a metric + + Returns: + A BoundedTrie object. 
+ """ + namespace = Metrics.get_namespace(namespace) + return Metrics.DelegatingBoundedTrie(MetricName(namespace, name)) + class DelegatingCounter(Counter): """Metrics Counter that Delegates functionality to MetricsEnvironment.""" def __init__( @@ -164,12 +182,19 @@ def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[method-assign] + class DelegatingBoundedTrie(BoundedTrie): + """Metrics StringSet that Delegates functionality to MetricsEnvironment.""" + def __init__(self, metric_name: MetricName) -> None: + super().__init__(metric_name) + self.add = MetricUpdater(cells.BoundedTrieCell, metric_name) # type: ignore[method-assign] + class MetricResults(object): COUNTERS = "counters" DISTRIBUTIONS = "distributions" GAUGES = "gauges" STRINGSETS = "string_sets" + BOUNDED_TRIES = "bounded_tries" @staticmethod def _matches_name(filter: 'MetricsFilter', metric_key: 'MetricKey') -> bool: @@ -318,8 +343,8 @@ class Lineage: SINK = "sinks" _METRICS = { - SOURCE: Metrics.string_set(LINEAGE_NAMESPACE, SOURCE), - SINK: Metrics.string_set(LINEAGE_NAMESPACE, SINK) + SOURCE: Metrics.bounded_trie(LINEAGE_NAMESPACE, SOURCE), + SINK: Metrics.bounded_trie(LINEAGE_NAMESPACE, SINK) } def __init__(self, label: str) -> None: @@ -368,8 +393,32 @@ def get_fq_name( return ':'.join((system, subtype, segs)) return ':'.join((system, segs)) + @staticmethod + def _get_fqn_parts( + system: str, + *segments: str, + subtype: Optional[str] = None, + last_segment_sep: Optional[str] = None) -> Iterator[str]: + yield system + ':' + if subtype: + yield subtype + ':' + if segments: + for segment in segments[:-1]: + yield segment + '.' 
+ if last_segment_sep: + sub_segments = segments[-1].split(last_segment_sep) + for sub_segment in sub_segments[:-1]: + yield sub_segment + last_segment_sep + yield sub_segments[-1] + else: + yield segments[-1] + def add( - self, system: str, *segments: str, subtype: Optional[str] = None) -> None: + self, + system: str, + *segments: str, + subtype: Optional[str] = None, + last_segment_sep: Optional[str] = None) -> None: """ Adds the given details as Lineage. @@ -390,21 +439,35 @@ def add( The first positional argument serves as system, if full segments are provided, or the full FQN if it is provided as a single argument. """ - system_or_details = system - if len(segments) == 0 and subtype is None: - self.metric.add(system_or_details) - else: - self.metric.add(self.get_fq_name(system, *segments, subtype=subtype)) + self.add_raw( + *self._get_fqn_parts( + system, + *segments, + subtype=subtype, + last_segment_sep=last_segment_sep)) + + def add_raw(self, *rollup_segments: str) -> None: + """Adds the given fqn as lineage. + + `rollup_segments` should be an iterable of strings whose concatenation + is a valid Dataplex FQN. In particular, this means they will often have + trailing delimiters. 
+ """ + self.metric.add(rollup_segments) @staticmethod - def query(results: MetricResults, label: str) -> Set[str]: + def query(results: MetricResults, + label: str, + truncated_marker: str = '*') -> Set[str]: if not label in Lineage._METRICS: raise ValueError("Label {} does not exist for Lineage", label) response = results.query( MetricsFilter().with_namespace(Lineage.LINEAGE_NAMESPACE).with_name( - label))[MetricResults.STRINGSETS] + label))[MetricResults.BOUNDED_TRIES] result = set() for metric in response: - result.update(metric.committed) - result.update(metric.attempted) + for fqn in metric.committed.flattened(): + result.add(''.join(fqn[:-1]) + (truncated_marker if fqn[-1] else '')) + for fqn in metric.attempted.flattened(): + result.add(''.join(fqn[:-1]) + (truncated_marker if fqn[-1] else '')) return result diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py index 524a2143172d..2e2e51b267a7 100644 --- a/sdks/python/apache_beam/metrics/metric_test.py +++ b/sdks/python/apache_beam/metrics/metric_test.py @@ -271,14 +271,19 @@ def test_fq_name(self): def test_add(self): lineage = Lineage(Lineage.SOURCE) - stringset = set() + added = set() # override - lineage.metric = stringset + lineage.metric = added lineage.add("s", "1", "2") lineage.add("s:3.4") lineage.add("s", "5", "6.7") lineage.add("s", "1", "2", subtype="t") - self.assertSetEqual(stringset, {"s:1.2", "s:3.4", "s:t:1.2", "s:5.`6.7`"}) + lineage.add("sys", "seg1", "seg2", "seg3/part2/part3", last_segment_sep='/') + self.assertSetEqual( + added, + {('s:', '1.', '2'), ('s:3.4:', ), ('s:', '5.', '6.7'), + ('s:', 't:', '1.', '2'), + ('sys:', 'seg1.', 'seg2.', 'seg3/', 'part2/', 'part3')}) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/metrics/metricbase.py b/sdks/python/apache_beam/metrics/metricbase.py index 7819dbb093a5..9b35bb24f895 100644 --- a/sdks/python/apache_beam/metrics/metricbase.py +++ 
b/sdks/python/apache_beam/metrics/metricbase.py @@ -43,6 +43,7 @@ 'Distribution', 'Gauge', 'StringSet', + 'BoundedTrie', 'Histogram', 'MetricName' ] @@ -152,6 +153,14 @@ def add(self, value): raise NotImplementedError +class BoundedTrie(Metric): + """BoundedTrie Metric interface. + + Reports a bounded, trie-structured set of string sequences during pipeline execution.""" + def add(self, value): + raise NotImplementedError + + class Histogram(Metric): """Histogram Metric interface. diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index 5227a4c9872b..cb4e60e218f6 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -27,6 +27,7 @@ from apache_beam.coders import coder_impl from apache_beam.coders import coders +from apache_beam.metrics.cells import BoundedTrieData from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import DistributionResult from apache_beam.metrics.cells import GaugeData @@ -50,11 +51,14 @@ common_urns.monitoring_info_specs.USER_DISTRIBUTION_INT64.spec.urn) USER_GAUGE_URN = common_urns.monitoring_info_specs.USER_LATEST_INT64.spec.urn USER_STRING_SET_URN = common_urns.monitoring_info_specs.USER_SET_STRING.spec.urn +USER_BOUNDED_TRIE_URN = ( + common_urns.monitoring_info_specs.USER_BOUNDED_TRIE.spec.urn) USER_METRIC_URNS = set([ USER_COUNTER_URN, USER_DISTRIBUTION_URN, USER_GAUGE_URN, - USER_STRING_SET_URN + USER_STRING_SET_URN, + USER_BOUNDED_TRIE_URN, ]) WORK_REMAINING_URN = common_urns.monitoring_info_specs.WORK_REMAINING.spec.urn WORK_COMPLETED_URN = common_urns.monitoring_info_specs.WORK_COMPLETED.spec.urn @@ -72,11 +76,13 @@ LATEST_INT64_TYPE = common_urns.monitoring_info_types.LATEST_INT64_TYPE.urn PROGRESS_TYPE = common_urns.monitoring_info_types.PROGRESS_TYPE.urn STRING_SET_TYPE = common_urns.monitoring_info_types.SET_STRING_TYPE.urn +BOUNDED_TRIE_TYPE = 
common_urns.monitoring_info_types.BOUNDED_TRIE_TYPE.urn COUNTER_TYPES = set([SUM_INT64_TYPE]) DISTRIBUTION_TYPES = set([DISTRIBUTION_INT64_TYPE]) GAUGE_TYPES = set([LATEST_INT64_TYPE]) STRING_SET_TYPES = set([STRING_SET_TYPE]) +BOUNDED_TRIE_TYPES = set([BOUNDED_TRIE_TYPE]) # TODO(migryz) extract values from beam_fn_api.proto::MonitoringInfoLabels PCOLLECTION_LABEL = ( @@ -163,6 +169,14 @@ def extract_string_set_value(monitoring_info_proto): return set(coder.decode(monitoring_info_proto.payload)) +def extract_bounded_trie_value(monitoring_info_proto): + if not is_bounded_trie(monitoring_info_proto): + raise ValueError('Unsupported type %s' % monitoring_info_proto.type) + + return BoundedTrieData.from_proto( + metrics_pb2.BoundedTrie.FromString(monitoring_info_proto.payload)) + + def create_labels(ptransform=None, namespace=None, name=None, pcollection=None): """Create the label dictionary based on the provided values. @@ -320,6 +334,23 @@ def user_set_string(namespace, name, metric, ptransform=None): USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels) +def user_bounded_trie(namespace, name, metric, ptransform=None): + """Return the bounded trie monitoring info for the URN, metric and labels. + + Args: + namespace: User-defined namespace of BoundedTrie. + name: Name of BoundedTrie. + metric: The BoundedTrieData representing the metrics. + ptransform: The ptransform id used as a label. + """ + labels = create_labels(ptransform=ptransform, namespace=namespace, name=name) + return create_monitoring_info( + USER_BOUNDED_TRIE_URN, + BOUNDED_TRIE_TYPE, + metric.to_proto().SerializeToString(), + labels) + + def create_monitoring_info( urn, type_urn, payload, labels=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, type, metric and labels. 
@@ -360,6 +391,11 @@ def is_string_set(monitoring_info_proto): return monitoring_info_proto.type in STRING_SET_TYPES +def is_bounded_trie(monitoring_info_proto): + """Returns true if the monitoring info is a BoundedTrie metric.""" + return monitoring_info_proto.type in BOUNDED_TRIE_TYPES + + def is_user_monitoring_info(monitoring_info_proto): """Returns true if the monitoring info is a user metric.""" return monitoring_info_proto.urn in USER_METRIC_URNS @@ -367,7 +403,7 @@ def is_user_monitoring_info(monitoring_info_proto): def extract_metric_result_map_value( monitoring_info_proto -) -> Union[None, int, DistributionResult, GaugeResult, set]: +) -> Union[None, int, DistributionResult, GaugeResult, set, BoundedTrieData]: """Returns the relevant GaugeResult, DistributionResult or int value for counter metric, set for StringSet metric. @@ -385,6 +421,8 @@ def extract_metric_result_map_value( return GaugeResult(GaugeData(value, timestamp)) if is_string_set(monitoring_info_proto): return extract_string_set_value(monitoring_info_proto) + if is_bounded_trie(monitoring_info_proto): + return extract_bounded_trie_value(monitoring_info_proto) return None diff --git a/sdks/python/apache_beam/ml/rag/__init__.py b/sdks/python/apache_beam/ml/rag/__init__.py new file mode 100644 index 000000000000..554beb9d7aba --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/__init__.py @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Apache Beam RAG (Retrieval Augmented Generation) components. +This package provides components for building RAG pipelines in Apache Beam, +including: +- Chunking +- Embedding generation +- Vector storage +- Vector search enrichment +""" diff --git a/sdks/python/apache_beam/ml/rag/chunking/__init__.py b/sdks/python/apache_beam/ml/rag/chunking/__init__.py new file mode 100644 index 000000000000..34a6a966b19e --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Chunking components for RAG pipelines. +This module provides components for splitting text into chunks for RAG +pipelines. 
+""" diff --git a/sdks/python/apache_beam/ml/rag/chunking/base.py b/sdks/python/apache_beam/ml/rag/chunking/base.py new file mode 100644 index 000000000000..626a6ea8abbe --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/base.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import functools +from collections.abc import Callable +from typing import Any +from typing import Dict +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import MLTransformProvider + +ChunkIdFn = Callable[[Chunk], str] + + +def _assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk): + chunk.id = chunk_id_fn(chunk) + return chunk + + +class ChunkingTransformProvider(MLTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): + """Base class for chunking transforms in RAG pipelines. + + ChunkingTransformProvider defines the interface for splitting documents + into chunks for embedding and retrieval. Implementations should define how + to split content while preserving metadata and managing chunk IDs. 
+ + The transform flow: + - Takes input documents with content and metadata + - Splits content into chunks using implementation-specific logic + - Preserves document metadata in resulting chunks + - Optionally assigns unique IDs to chunks (configurable via chunk_id_fn) + + Example usage: + >>> class MyChunker(ChunkingTransformProvider): + ... def get_splitter_transform(self): + ... return beam.ParDo(MySplitterDoFn()) + ... + >>> chunker = MyChunker(chunk_id_fn=my_id_function) + >>> + >>> with beam.Pipeline() as p: + ... chunks = ( + ... p + ... | beam.Create([{'text': 'document...', 'source': 'doc.txt'}]) + ... | MLTransform(...).with_transform(chunker)) + + Args: + chunk_id_fn: Optional function to generate chunk IDs. If not provided, + random UUIDs will be used. Function should take a Chunk and return str. + """ + self.assign_chunk_id_fn = functools.partial( + _assign_chunk_id, chunk_id_fn) if chunk_id_fn is not None else None + + @abc.abstractmethod + def get_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + """Creates a transform that emits splits for given content.""" + raise NotImplementedError( + "Subclasses must implement get_splitter_transform") + + def get_ptransform_for_processing( + self, **kwargs + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + """Creates transform for processing documents into chunks.""" + ptransform = ( + "Split document" >> + self.get_splitter_transform().with_output_types(Chunk)) + if self.assign_chunk_id_fn: + ptransform = ( + ptransform | "Assign chunk id" >> beam.Map( + self.assign_chunk_id_fn).with_output_types(Chunk)) + return ptransform diff --git a/sdks/python/apache_beam/ml/rag/chunking/base_test.py b/sdks/python/apache_beam/ml/rag/chunking/base_test.py new file mode 100644 index 000000000000..54e25591c348 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/base_test.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache 
Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for apache_beam.ml.rag.chunking.base.""" + +import unittest +from typing import Any +from typing import Dict +from typing import Optional + +import pytest + +import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + + +class WordSplitter(beam.DoFn): + def process(self, element): + words = element['text'].split() + for i, word in enumerate(words): + yield Chunk( + content=Content(text=word), + index=i, + metadata={'source': element['source']}) + + +class InvalidChunkingProvider(ChunkingTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): + super().__init__(chunk_id_fn=chunk_id_fn) + + +class MockChunkingProvider(ChunkingTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): + super().__init__(chunk_id_fn=chunk_id_fn) + + def get_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, 
Any]], + beam.PCollection[Chunk]]: + return beam.ParDo(WordSplitter()) + + +def chunk_equals(expected, actual): + """Custom equality function for Chunk objects.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + # Don't compare IDs since they're randomly generated + return ( + expected.index == actual.index and expected.content == actual.content and + expected.metadata == actual.metadata) + + +def id_equals(expected, actual): + """Custom equality function for Chunk object id's.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + return (expected.id == actual.id) + + +@pytest.mark.uses_transformers +class ChunkingTransformProviderTest(unittest.TestCase): + def setUp(self): + self.test_doc = {'text': 'hello world test', 'source': 'test.txt'} + + def test_doesnt_override_get_text_splitter_transform(self): + provider = InvalidChunkingProvider() + with self.assertRaises(NotImplementedError): + provider.get_splitter_transform() + + def test_chunking_transform(self): + """Test the complete chunking transform.""" + provider = MockChunkingProvider() + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.test_doc]) + | provider.get_ptransform_for_processing()) + + expected = [ + Chunk( + content=Content(text="hello"), + index=0, + metadata={'source': 'test.txt'}), + Chunk( + content=Content(text="world"), + index=1, + metadata={'source': 'test.txt'}), + Chunk( + content=Content(text="test"), + index=2, + metadata={'source': 'test.txt'}) + ] + + assert_that(chunks, equal_to(expected, equals_fn=chunk_equals)) + + def test_custom_chunk_id_fn(self): + """Test a custom chunk id function.""" + def source_index_id_fn(chunk: Chunk): + return f"{chunk.metadata['source']}_{chunk.index}" + + provider = MockChunkingProvider(chunk_id_fn=source_index_id_fn) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.test_doc]) + | provider.get_ptransform_for_processing()) + + expected 
= [ + Chunk(content=Content(text="hello"), id="test.txt_0"), + Chunk(content=Content(text="world"), id="test.txt_1"), + Chunk(content=Content(text="test"), id="test.txt_2") + ] + + assert_that(chunks, equal_to(expected, equals_fn=id_equals)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain.py b/sdks/python/apache_beam/ml/rag/chunking/langchain.py new file mode 100644 index 000000000000..9e3b6b0c8ef9 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content + +try: + from langchain.text_splitter import TextSplitter +except ImportError: + TextSplitter = None + + +class LangChainChunker(ChunkingTransformProvider): + def __init__( + self, + text_splitter: TextSplitter, + document_field: str, + metadata_fields: List[str], + chunk_id_fn: Optional[ChunkIdFn] = None): + """A ChunkingTransformProvider that uses LangChain text splitters. + + This provider integrates LangChain's text splitting capabilities into + Beam's MLTransform framework. It supports various text splitting strategies + through LangChain's TextSplitter interface, including recursive character + splitting and other methods. + + The provider: + - Takes documents with text content and metadata + - Splits text using configured LangChain splitter + - Preserves document metadata in resulting chunks + - Assigns unique IDs to chunks (configurable via chunk_id_fn) + + Example usage: + ```python + from langchain.text_splitter import RecursiveCharacterTextSplitter + + splitter = RecursiveCharacterTextSplitter( + chunk_size=100, + chunk_overlap=20 + ) + + chunker = LangChainChunker(text_splitter=splitter) + + with beam.Pipeline() as p: + chunks = ( + p + | beam.Create([{'text': 'long document...', 'source': 'doc.txt'}]) + | MLTransform(...).with_transform(chunker)) + ``` + + Args: + text_splitter: A LangChain TextSplitter instance that defines how + documents are split into chunks. + metadata_fields: List of field names to copy from input documents to + chunk metadata. These fields will be preserved in each chunk created + from the document. 
+ chunk_id_fn: Optional function that takes a Chunk and returns a str to + generate chunk IDs. If not provided, random UUIDs will be used. + """ + if not TextSplitter: + raise ImportError( + "langchain is required to use LangChainChunker" + "Please install it with using `pip install langchain`.") + if not isinstance(text_splitter, TextSplitter): + raise TypeError("text_splitter must be a LangChain TextSplitter") + if not document_field: + raise ValueError("document_field cannot be empty") + super().__init__(chunk_id_fn) + self.text_splitter = text_splitter + self.document_field = document_field + self.metadata_fields = metadata_fields + + def get_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + return "Langchain text split" >> beam.ParDo( + _LangChainTextSplitter( + text_splitter=self.text_splitter, + document_field=self.document_field, + metadata_fields=self.metadata_fields)) + + +class _LangChainTextSplitter(beam.DoFn): + def __init__( + self, + text_splitter: TextSplitter, + document_field: str, + metadata_fields: List[str]): + self.text_splitter = text_splitter + self.document_field = document_field + self.metadata_fields = metadata_fields + + def process(self, element): + text_chunks = self.text_splitter.split_text(element[self.document_field]) + metadata = {field: element[field] for field in self.metadata_fields} + for i, text_chunk in enumerate(text_chunks): + yield Chunk(content=Content(text=text_chunk), index=i, metadata=metadata) diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py new file mode 100644 index 000000000000..83a4fc1a778f --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -0,0 +1,217 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for apache_beam.ml.rag.chunking.langchain.""" + +import unittest + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +try: + from apache_beam.ml.rag.chunking.langchain import LangChainChunker + + from langchain.text_splitter import ( + CharacterTextSplitter, RecursiveCharacterTextSplitter) + LANGCHAIN_AVAILABLE = True +except ImportError: + LANGCHAIN_AVAILABLE = False + +# Import optional dependencies +try: + from transformers import AutoTokenizer + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + + +def chunk_equals(expected, actual): + """Custom equality function for Chunk objects.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + return ( + expected.content == actual.content and expected.index == actual.index and + expected.metadata == actual.metadata) + + +@unittest.skipIf(not LANGCHAIN_AVAILABLE, 'langchain is not installed.') +class LangChainChunkingTest(unittest.TestCase): + def setUp(self): + self.simple_text = { + 'content': 'This is a simple test document. It has multiple sentences. 
' + 'We will use it to test basic splitting.', + 'source': 'simple.txt', + 'language': 'en' + } + + self.complex_text = { + 'content': ( + 'The patient arrived at 2 p.m. yesterday. ' + 'Initial assessment was completed. ' + 'Lab results showed normal ranges. ' + 'Follow-up scheduled for next week.'), + 'source': 'medical.txt', + 'language': 'en' + } + + def test_no_metadata_fields(self): + """Test chunking with no metadata fields specified.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='content', metadata_fields=[], text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + + assert_that(chunks, lambda x: all(c.metadata == {} for c in x)) + + def test_multiple_metadata_fields(self): + """Test chunking with multiple metadata fields.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='content', + metadata_fields=['source', 'language'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that( + chunks, + lambda x: all( + c.metadata == { + 'source': 'simple.txt', 'language': 'en' + } for c in x)) + + def test_recursive_splitter_no_overlap(self): + """Test RecursiveCharacterTextSplitter with no overlap.""" + splitter = RecursiveCharacterTextSplitter( + chunk_size=30, chunk_overlap=0, separators=[". 
"]) + provider = LangChainChunker( + document_field='content', + metadata_fields=['source'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks, lambda x: all(len(c.content.text) <= 30 for c in x)) + + @unittest.skipIf(not TRANSFORMERS_AVAILABLE, "transformers not available") + def test_huggingface_tokenizer_splitter(self): + """Test text splitter created from HuggingFace tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer( + tokenizer, + chunk_size=10, # tokens + chunk_overlap=2 # tokens + ) + + provider = LangChainChunker( + document_field='content', + metadata_fields=['source'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + + def check_token_lengths(chunks): + for chunk in chunks: + # Verify each chunk's token length is within limits + num_tokens = len(tokenizer.encode(chunk.content.text)) + if not num_tokens <= 10: + raise AssertionError( + f"Chunk has {num_tokens} tokens, expected <= 10") + return True + + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks, check_token_lengths) + + def test_invalid_document_field(self): + """Test that using an invalid document field raises KeyError.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='nonexistent', + metadata_fields={}, + text_splitter=splitter) + + with self.assertRaises(KeyError): + with TestPipeline() as p: + _ = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + + 
def test_empty_document_field(self): + """Test that using an invalid document field raises KeyError.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + + with self.assertRaises(ValueError): + _ = LangChainChunker( + document_field='', metadata_fields={}, text_splitter=splitter) + + def test_invalid_text_splitter(self): + """Test that using an invalid document field raises KeyError.""" + + with self.assertRaises(TypeError): + _ = LangChainChunker( + document_field='nonexistent', text_splitter="Not a text splitter!") + + def test_empty_text(self): + """Test that empty text produces no chunks.""" + empty_doc = {'content': '', 'source': 'empty.txt'} + + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='content', + metadata_fields=['source'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([empty_doc]) + | provider.get_ptransform_for_processing()) + + assert_that(chunks, equal_to([])) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/embeddings/__init__.py b/sdks/python/apache_beam/ml/rag/embeddings/__init__.py new file mode 100644 index 000000000000..d2cdb63c0bde --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Embedding components for RAG pipelines. +This module provides components for generating embeddings in RAG pipelines. +""" diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base.py b/sdks/python/apache_beam/ml/rag/embeddings/base.py new file mode 100644 index 000000000000..25dc3ee47e80 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/base.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import List + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import EmbeddingTypeAdapter + + +def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Chunk]: + """Creates adapter for converting between Chunk and Embedding types. 
+ + The adapter: + - Extracts text from Chunk.content.text for embedding + - Creates Embedding objects from model output + - Sets Embedding in Chunk.embedding + + Returns: + EmbeddingTypeAdapter configured for RAG pipeline types + """ + return EmbeddingTypeAdapter( + input_fn=_extract_chunk_text, output_fn=_add_embedding_fn) + + +def _extract_chunk_text(chunks: Sequence[Chunk]) -> List[str]: + """Extract text from chunks for embedding.""" + chunk_texts = [] + for chunk in chunks: + if not chunk.content.text: + raise ValueError("Expected chunk text content.") + chunk_texts.append(chunk.content.text) + return chunk_texts + + +def _add_embedding_fn( + chunks: Sequence[Chunk], embeddings: Sequence[List[float]]) -> List[Chunk]: + """Create Embeddings from chunks and embedding vectors.""" + for chunk, embedding in zip(chunks, embeddings): + chunk.embedding = Embedding(dense_embedding=embedding) + return list(chunks) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py new file mode 100644 index 000000000000..3a27ae8e7ebb --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding + + +class RAGBaseEmbeddingsTest(unittest.TestCase): + def setUp(self): + self.test_chunks = [ + Chunk( + content=Content(text="This is a test sentence."), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }), + Chunk( + content=Content(text="Another example."), + id="2", + metadata={ + "source": "test2.txt", "language": "en" + }) + ] + + def test_adapter_input_conversion(self): + """Test the RAG type adapter converts correctly.""" + adapter = create_rag_adapter() + + # Test input conversion + texts = adapter.input_fn(self.test_chunks) + self.assertEqual(texts, ["This is a test sentence.", "Another example."]) + + def test_adapter_input_conversion_missing_text_content(self): + """Test the RAG type adapter converts correctly.""" + adapter = create_rag_adapter() + + # Test input conversion + with self.assertRaisesRegex(ValueError, "Expected chunk text content"): + adapter.input_fn([ + Chunk( + content=Content(), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }) + ]) + + def test_adapter_output_conversion(self): + """Test the RAG type adapter converts correctly.""" + # Test output conversion + mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + # Expected outputs + expected = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + metadata={ + 'source': 'test.txt', 'language': 'en' + }, + content=Content(text='This is a test sentence.')), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.4, 0.5, 0.6]), + metadata={ + 'source': 'test2.txt', 'language': 'en' + }, + content=Content(text='Another example.')), + ] + adapter = create_rag_adapter() + + embeddings = adapter.output_fn(self.test_chunks, mock_embeddings) + self.assertListEqual(embeddings, expected) + + +if __name__ == 
'__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py new file mode 100644 index 000000000000..4cb0aecd6e82 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAG-specific embedding implementations using HuggingFace models.""" + +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler + +try: + from sentence_transformers import SentenceTransformer +except ImportError: + SentenceTransformer = None + + +class HuggingfaceTextEmbeddings(EmbeddingsManager): + def __init__( + self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): + """Utilizes huggingface SentenceTransformer embeddings for RAG pipeline. 
+ + Args: + model_name: Name of the sentence-transformers model to use + max_seq_length: Maximum sequence length for the model + **kwargs: Additional arguments passed to + :class:`~apache_beam.ml.transforms.base.EmbeddingsManager` + constructor including ModelHandler arguments + """ + if not SentenceTransformer: + raise ImportError( + "sentence-transformers is required to use " + "HuggingfaceTextEmbeddings." + "Please install it with using `pip install sentence-transformers`.") + super().__init__(type_adapter=create_rag_adapter(), **kwargs) + self.model_name = model_name + self.max_seq_length = max_seq_length + self.model_class = SentenceTransformer + + def get_model_handler(self): + """Returns model handler configured with RAG adapter.""" + return _SentenceTransformerModelHandler( + model_class=self.model_class, + max_seq_length=self.max_seq_length, + model_name=self.model_name, + load_model_args=self.load_model_args, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model) + + def get_ptransform_for_processing( + self, **kwargs + ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]: + """Returns PTransform that uses the RAG adapter.""" + return RunInference( + model_handler=_TextEmbeddingHandler(self), + inference_args=self.inference_args).with_output_types(Chunk) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py new file mode 100644 index 000000000000..aa63d13025a1 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for apache_beam.ml.rag.embeddings.huggingface.""" + +import tempfile +import unittest + +import pytest + +import apache_beam as beam +from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=unused-import +try: + from sentence_transformers import SentenceTransformer + SENTENCE_TRANSFORMERS_AVAILABLE = True +except ImportError: + SENTENCE_TRANSFORMERS_AVAILABLE = False + + +def chunk_approximately_equals(expected, actual): + """Compare embeddings allowing for numerical differences.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + + return ( + expected.id == actual.id and expected.metadata == actual.metadata and + expected.content == actual.content and + len(expected.embedding.dense_embedding) == len( + actual.embedding.dense_embedding) and + all(isinstance(x, float) for x in actual.embedding.dense_embedding)) + + +@pytest.mark.uses_transformers +@unittest.skipIf( + not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available") +class 
HuggingfaceTextEmbeddingsTest(unittest.TestCase): + def setUp(self): + self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_') + self.test_chunks = [ + Chunk( + content=Content(text="This is a test sentence."), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }), + Chunk( + content=Content(text="Another example."), + id="2", + metadata={ + "source": "test.txt", "language": "en" + }) + ] + + def test_embedding_pipeline(self): + expected = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.0] * 384), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="This is a test sentence.")), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.0] * 384), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="Another example.")) + ] + embedder = HuggingfaceTextEmbeddings( + model_name="sentence-transformers/all-MiniLM-L6-v2") + + with TestPipeline() as p: + embeddings = ( + p + | beam.Create(self.test_chunks) + | MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedder)) + + assert_that( + embeddings, equal_to(expected, equals_fn=chunk_approximately_equals)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py new file mode 100644 index 000000000000..79429899e4c1 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Core types for RAG pipelines. +This module contains the core dataclasses used throughout the RAG pipeline +implementation, including Chunk and Embedding types that define the data +contracts between different stages of the pipeline. +""" + +import uuid +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + + +@dataclass +class Content: + """Container for embeddable content. Add new types as when as necessary. + + Args: + text: Text content to be embedded + """ + text: Optional[str] = None + + +@dataclass +class Embedding: + """Represents vector embeddings. + + Args: + dense_embedding: Dense vector representation + sparse_embedding: Optional sparse vector representation for hybrid + search + """ + dense_embedding: Optional[List[float]] = None + # For hybrid search + sparse_embedding: Optional[Tuple[List[int], List[float]]] = None + + +@dataclass +class Chunk: + """Represents a chunk of embeddable content with metadata. 
+ + Args: + content: The actual content of the chunk + id: Unique identifier for the chunk + index: Index of this chunk within the original document + metadata: Additional metadata about the chunk (e.g., document source) + embedding: Vector embeddings of the content + """ + content: Content + id: str = field(default_factory=lambda: str(uuid.uuid4())) + index: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + embedding: Optional[Embedding] = None diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index a963f602a06d..57a5efd3ff0e 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -15,18 +15,24 @@ # limitations under the License. import abc -import collections +import functools import logging import os import tempfile import uuid +from collections.abc import Callable from collections.abc import Mapping from collections.abc import Sequence +from dataclasses import dataclass from typing import Any +from typing import Dict from typing import Generic +from typing import Iterable +from typing import List from typing import Optional from typing import TypeVar from typing import Union +from typing import cast import jsonpickle import numpy as np @@ -62,36 +68,31 @@ # Output of the apply() method of BaseOperation. 
OperationOutputT = TypeVar('OperationOutputT') +# Input to the EmbeddingTypeAdapter input_fn +EmbeddingTypeAdapterInputT = TypeVar( + 'EmbeddingTypeAdapterInputT') # e.g., Chunk +# Output of the EmbeddingTypeAdapter output_fn +EmbeddingTypeAdapterOutputT = TypeVar( + 'EmbeddingTypeAdapterOutputT') # e.g., Embedding -def _convert_list_of_dicts_to_dict_of_lists( - list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]: - keys_to_element_list = collections.defaultdict(list) - input_keys = list_of_dicts[0].keys() - for d in list_of_dicts: - if set(d.keys()) != set(input_keys): - extra_keys = set(d.keys()) - set(input_keys) if len( - d.keys()) > len(input_keys) else set(input_keys) - set(d.keys()) - raise RuntimeError( - f'All the dicts in the input data should have the same keys. ' - f'Got: {extra_keys} instead.') - for key, value in d.items(): - keys_to_element_list[key].append(value) - return keys_to_element_list - - -def _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: - batch_length = len(next(iter(dict_of_lists.values()))) - result: list[dict[str, Any]] = [{} for _ in range(batch_length)] - # all the values in the dict_of_lists should have same length - for key, values in dict_of_lists.items(): - assert len(values) == batch_length, ( - "This function expects all the values " - "in the dict_of_lists to have same length." - ) - for i in range(len(values)): - result[i][key] = values[i] - return result + +@dataclass +class EmbeddingTypeAdapter(Generic[EmbeddingTypeAdapterInputT, + EmbeddingTypeAdapterOutputT]): + """Adapts input types to text for embedding and converts output embeddings. 
+ + Args: + input_fn: Function to extract text for embedding from input type + output_fn: Function to create output type from input and embeddings + """ + input_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT]], List[str]] + output_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT], Sequence[Any]], + List[EmbeddingTypeAdapterOutputT]] + + def __reduce__(self): + """Custom serialization that preserves type information during + jsonpickle.""" + return (self.__class__, (self.input_fn, self.output_fn)) def _map_errors_to_beam_row(element, cls_name=None): @@ -182,13 +183,74 @@ def append_transform(self, transform: BaseOperation): """ +def _dict_input_fn(columns: Sequence[str], + batch: Sequence[Dict[str, Any]]) -> List[str]: + """Extract text from specified columns in batch.""" + if not batch or not isinstance(batch[0], dict): + raise TypeError( + 'Expected data to be dicts, got ' + f'{type(batch[0])} instead.') + + result = [] + expected_keys = set(batch[0].keys()) + expected_columns = set(columns) + # Process one batch item at a time + for item in batch: + item_keys = item.keys() + if set(item_keys) != expected_keys: + extra_keys = item_keys - expected_keys + missing_keys = expected_keys - item_keys + raise RuntimeError( + f'All dicts in batch must have the same keys. 
' + f'extra keys: {extra_keys}, ' + f'missing keys: {missing_keys}') + missing_columns = expected_columns - item_keys + if (missing_columns): + raise RuntimeError( + f'Data does not contain the following columns ' + f': {missing_columns}.') + + # Get all columns for this item + for col in columns: + result.append(item[col]) + return result + + +def _dict_output_fn( + columns: Sequence[str], + batch: Sequence[Dict[str, Any]], + embeddings: Sequence[Any]) -> List[Dict[str, Any]]: + """Map embeddings back to columns in batch.""" + result = [] + for batch_idx, item in enumerate(batch): + for col_idx, col in enumerate(columns): + embedding_idx = batch_idx * len(columns) + col_idx + item[col] = embeddings[embedding_idx] + result.append(item) + return result + + +def _create_dict_adapter( + columns: List[str]) -> EmbeddingTypeAdapter[Dict[str, Any], Dict[str, Any]]: + """Create adapter for dict-based processing.""" + return EmbeddingTypeAdapter[Dict[str, Any], Dict[str, Any]]( + input_fn=cast( + Callable[[Sequence[Dict[str, Any]]], List[str]], + functools.partial(_dict_input_fn, columns)), + output_fn=cast( + Callable[[Sequence[Dict[str, Any]], Sequence[Any]], + List[Dict[str, Any]]], + functools.partial(_dict_output_fn, columns))) + + # TODO:https://github.com/apache/beam/issues/29356 # Add support for inference_fn class EmbeddingsManager(MLTransformProvider): def __init__( self, - columns: list[str], *, + columns: Optional[list[str]] = None, + type_adapter: Optional[EmbeddingTypeAdapter] = None, # common args for all ModelHandlers. 
load_model_args: Optional[dict[str, Any]] = None, min_batch_size: Optional[int] = None, @@ -200,6 +262,12 @@ def __init__( self.max_batch_size = max_batch_size self.large_model = large_model self.columns = columns + if columns is not None: + self.type_adapter = _create_dict_adapter(columns) + elif type_adapter is not None: + self.type_adapter = type_adapter + else: + raise ValueError("Either columns or type_adapter must be specified") self.inference_args = kwargs.pop('inference_args', {}) if kwargs: @@ -616,38 +684,6 @@ def load_model(self): def _validate_column_data(self, batch): pass - def _validate_batch(self, batch: Sequence[dict[str, Any]]): - if not batch or not isinstance(batch[0], dict): - raise TypeError( - 'Expected data to be dicts, got ' - f'{type(batch[0])} instead.') - - def _process_batch( - self, - dict_batch: dict[str, list[Any]], - model: ModelT, - inference_args: Optional[dict[str, Any]]) -> dict[str, list[Any]]: - result: dict[str, list[Any]] = collections.defaultdict(list) - input_keys = dict_batch.keys() - missing_columns_in_data = set(self.columns) - set(input_keys) - if missing_columns_in_data: - raise RuntimeError( - f'Data does not contain the following columns ' - f': {missing_columns_in_data}.') - for key, batch in dict_batch.items(): - if key in self.columns: - self._validate_column_data(batch) - prediction = self._underlying.run_inference( - batch, model, inference_args) - if isinstance(prediction, np.ndarray): - prediction = prediction.tolist() - result[key] = prediction # type: ignore[assignment] - else: - result[key] = prediction # type: ignore[assignment] - else: - result[key] = batch - return result - def run_inference( self, batch: Sequence[dict[str, list[str]]], @@ -659,12 +695,19 @@ def run_inference( a list of dicts. Each dict should have the same keys, and the shape should be of the same size for a single key across the batch. 
""" - self._validate_batch(batch) - dict_batch = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch) - transformed_batch = self._process_batch(dict_batch, model, inference_args) - return _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists=transformed_batch, - ) + embedding_input = self.embedding_config.type_adapter.input_fn(batch) + self._validate_column_data(batch=embedding_input) + prediction = self._underlying.run_inference( + embedding_input, model, inference_args) + # Convert prediction to Sequence[Any] + if isinstance(prediction, np.ndarray): + prediction_seq = prediction.tolist() + elif isinstance(prediction, Iterable) and not isinstance(prediction, + (str, bytes)): + prediction_seq = list(prediction) + else: + prediction_seq = [prediction] + return self.embedding_config.type_adapter.output_fn(batch, prediction_seq) def get_metrics_namespace(self) -> str: return ( diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 3db5a63b9542..1ef01acca18a 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -78,6 +78,10 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) + def test_ml_transform_no_read_or_write_artifact_lcoation(self): + with self.assertRaises(ValueError): + _ = base.MLTransform(transforms=[]) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_appends_transforms_to_process_handler_correctly(self): fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x']) @@ -354,6 +358,21 @@ def __repr__(self): return 'FakeEmbeddingsManager' +class InvalidEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_model_handler(self) -> ModelHandler: + InvalidEmbeddingsManager.__repr__ = lambda x: 'InvalidEmbeddingsManager' # type: ignore[method-assign] + return FakeModelHandler() + + def 
get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=base._TextEmbeddingHandler(self))) + + def __repr__(self): + return 'InvalidEmbeddingsManager' + + class TextEmbeddingHandlerTest(unittest.TestCase): def setUp(self) -> None: self.embedding_conig = FakeEmbeddingsManager(columns=['x']) @@ -362,6 +381,10 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.artifact_location) + def test_no_columns_or_type_adapter(self): + with self.assertRaises(ValueError): + _ = InvalidEmbeddingsManager() + def test_handler_with_incompatible_datatype(self): text_handler = base._TextEmbeddingHandler( embeddings_manager=self.embedding_conig) @@ -430,9 +453,9 @@ def test_handler_on_multiple_columns(self): 'x': "Apache Beam", 'y': "Hello world", 'z': 'unchanged' }, ] - self.embedding_conig.columns = ['x', 'y'] + embedding_config = FakeEmbeddingsManager(columns=['x', 'y']) expected_data = [{ - key: (value[::-1] if key in self.embedding_conig.columns else value) + key: (value[::-1] if key in embedding_config.columns else value) for key, value in d.items() } for d in data] @@ -440,9 +463,8 @@ def test_handler_on_multiple_columns(self): result = ( p | beam.Create(data) - | base.MLTransform( - write_artifact_location=self.artifact_location).with_transform( - self.embedding_conig)) + | base.MLTransform(write_artifact_location=self.artifact_location). 
+ with_transform(embedding_config)) assert_that( result, equal_to(expected_data), @@ -457,16 +479,15 @@ def test_handler_on_columns_not_exist_in_input_data(self): 'x': "Apache Beam", 'y': "Hello world" }, ] - self.embedding_conig.columns = ['x', 'y', 'a'] + embedding_config = FakeEmbeddingsManager(columns=['x', 'y', 'a']) with self.assertRaises(RuntimeError): with beam.Pipeline() as p: _ = ( p | beam.Create(data) - | base.MLTransform( - write_artifact_location=self.artifact_location).with_transform( - self.embedding_conig)) + | base.MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedding_config)) def test_handler_with_list_data(self): data = [{ @@ -550,7 +571,7 @@ def tearDown(self) -> None: shutil.rmtree(self.artifact_location) @unittest.skipIf(PIL is None, 'PIL module is not installed.') - def test_handler_with_incompatible_datatype(self): + def test_handler_with_non_dict_datatype(self): image_handler = base._ImageEmbeddingHandler( embeddings_manager=self.embedding_config) data = [ @@ -561,6 +582,24 @@ def test_handler_with_incompatible_datatype(self): with self.assertRaises(TypeError): image_handler.run_inference(data, None, None) + @unittest.skipIf(PIL is None, 'PIL module is not installed.') + def test_handler_with_non_image_datatype(self): + image_handler = base._ImageEmbeddingHandler( + embeddings_manager=self.embedding_config) + data = [ + { + 'x': 'hi there' + }, + { + 'x': 'not an image' + }, + { + 'x': 'image_path.jpg' + }, + ] + with self.assertRaises(TypeError): + image_handler.run_inference(data, None, None) + @unittest.skipIf(PIL is None, 'PIL module is not installed.') def test_handler_with_dict_inputs(self): img_one = PIL.Image.new(mode='RGB', size=(1, 1)) @@ -588,31 +627,37 @@ def test_handler_with_dict_inputs(self): class TestUtilFunctions(unittest.TestCase): - def test_list_of_dicts_to_dict_of_lists_normal(self): + def test_dict_input_fn_normal(self): + input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] + columns = 
['a', 'b'] + + expected_output = [1, 2, 3, 4] + self.assertEqual(base._dict_input_fn(columns, input_list), expected_output) + + def test_dict_output_fn_normal(self): input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] - expected_output = {'a': [1, 3], 'b': [2, 4]} + columns = ['a', 'b'] + embeddings = [1.1, 2.2, 3.3, 4.4] + + expected_output = [{'a': 1.1, 'b': 2.2}, {'a': 3.3, 'b': 4.4}] self.assertEqual( - base._convert_list_of_dicts_to_dict_of_lists(input_list), - expected_output) + base._dict_output_fn(columns, input_list, embeddings), expected_output) - def test_list_of_dicts_to_dict_of_lists_on_list_inputs(self): + def test_dict_input_fn_on_list_inputs(self): input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] - expected_output = {'a': [[1, 2, 10], [1]], 'b': [3, 5]} - self.assertEqual( - base._convert_list_of_dicts_to_dict_of_lists(input_list), - expected_output) + columns = ['a', 'b'] - def test_dict_of_lists_to_lists_of_dict_normal(self): - input_dict = {'a': [1, 3], 'b': [2, 4]} - expected_output = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] - self.assertEqual( - base._convert_dict_of_lists_to_lists_of_dict(input_dict), - expected_output) + expected_output = [[1, 2, 10], 3, [1], 5] + self.assertEqual(base._dict_input_fn(columns, input_list), expected_output) - def test_dict_of_lists_to_lists_of_dict_unequal_length(self): - input_dict = {'a': [1, 3], 'b': [2]} - with self.assertRaises(AssertionError): - base._convert_dict_of_lists_to_lists_of_dict(input_dict) + def test_dict_output_fn_on_list_inputs(self): + input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] + columns = ['a', 'b'] + embeddings = [1.1, 2.2, 3.3, 4.4] + + expected_output = [{'a': 1.1, 'b': 2.2}, {'a': 3.3, 'b': 4.4}] + self.assertEqual( + base._dict_output_fn(columns, input_list, embeddings), expected_output) class TestJsonPickleTransformAttributeManager(unittest.TestCase): diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py 
b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index 2162ed050c42..e492cb164222 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -133,7 +133,7 @@ def __init__( max_batch_size: The maximum batch size to be used for inference. large_model: Whether to share the model across processes. """ - super().__init__(columns, **kwargs) + super().__init__(columns=columns, **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length self.image_model = image_model @@ -219,7 +219,7 @@ def __init__( api_url: Optional[str] = None, **kwargs, ): - super().__init__(columns, **kwargs) + super().__init__(columns=columns, **kwargs) self._authorization_token = {"Authorization": f"Bearer {hf_token}"} self._model_name = model_name self.hf_token = hf_token diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index 6fe8320e758b..6df505508ae9 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -19,20 +19,27 @@ # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long # to install Vertex AI Python SDK. 
+import logging +import time from collections.abc import Iterable from collections.abc import Sequence from typing import Any from typing import Optional +from google.api_core.exceptions import ServerError +from google.api_core.exceptions import TooManyRequests from google.auth.credentials import Credentials import apache_beam as beam import vertexai +from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler +from apache_beam.metrics.metric import Metrics from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.utils import retry from vertexai.language_models import TextEmbeddingInput from vertexai.language_models import TextEmbeddingModel from vertexai.vision_models import Image @@ -51,6 +58,26 @@ "CLUSTERING" ] _BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time. +_MSEC_TO_SEC = 1000 + +LOGGER = logging.getLogger("VertexAIEmbeddings") + + +def _retry_on_appropriate_gcp_error(exception): + """ + Retry filter that returns True if a returned HTTP error code is 5xx or 429. + This is used to retry remote requests that fail, most notably 429 + (TooManyRequests.) + + Args: + exception: the returned exception encountered during the request/response + loop. + + Returns: + boolean indication whether or not the exception is a Server Error (5xx) or + a TooManyRequests (429) error. + """ + return isinstance(exception, (TooManyRequests, ServerError)) class _VertexAITextEmbeddingHandler(ModelHandler): @@ -74,6 +101,41 @@ def __init__( self.task_type = task_type self.title = title + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. 
+ # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. + self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + text_batch: Sequence[TextEmbeddingInput], + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(text_batch) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[str], @@ -89,7 +151,8 @@ def run_inference( text=text, title=self.title, task_type=self.task_type) for text in text_batch ] - embeddings_batch = model.get_embeddings(text_batch) + embeddings_batch = self.get_request( + text_batch=text_batch, model=model, throttle_delay_secs=5) embeddings.extend([el.values for el in embeddings_batch]) return embeddings @@ -173,6 +236,41 @@ def __init__( self.model_name = model_name self.dimension = dimension + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. + # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. 
+ self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + img: Image, + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(image=img, dimension=self.dimension) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[Image], @@ -182,8 +280,7 @@ def run_inference( embeddings = [] # Maximum request size for muli-model embedding models is 1. 
for img in batch: - embedding_response = model.get_embeddings( - image=img, dimension=self.dimension) + embedding_response = self.get_request(img, model, throttle_delay_secs=5) embeddings.append(embedding_response.image_embedding) return embeddings diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 97996bd6cbb2..4e65156f3bc7 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -42,6 +42,7 @@ import re import sys import time +import traceback import warnings from copy import copy from datetime import datetime @@ -557,13 +558,11 @@ def _cached_gcs_file_copy(self, from_path, to_path, sha256): source_file_names=[cached_path], destination_file_names=[to_path]) _LOGGER.info('Copied cached artifact from %s to %s', from_path, to_path) - @retry.with_exponential_backoff( - retry_filter=retry.retry_on_server_errors_and_timeout_filter) def _uncached_gcs_file_copy(self, from_path, to_path): to_folder, to_name = os.path.split(to_path) total_size = os.path.getsize(from_path) - with open(from_path, 'rb') as f: - self.stage_file(to_folder, to_name, f, total_size=total_size) + self.stage_file_with_retry( + to_folder, to_name, from_path, total_size=total_size) def _stage_resources(self, pipeline, options): google_cloud_options = options.view_as(GoogleCloudOptions) @@ -692,6 +691,41 @@ def stage_file( (gcs_or_local_path, e)) raise + @retry.with_exponential_backoff( + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def stage_file_with_retry( + self, + gcs_or_local_path, + file_name, + stream_or_path, + mime_type='application/octet-stream', + total_size=None): + + if isinstance(stream_or_path, str): + path = stream_or_path + with open(path, 'rb') as stream: + self.stage_file( + gcs_or_local_path, file_name, stream, mime_type, total_size) + elif isinstance(stream_or_path, io.IOBase): 
+ stream = stream_or_path + try: + self.stage_file( + gcs_or_local_path, file_name, stream, mime_type, total_size) + except Exception as exn: + if stream.seekable(): + # reset cursor for possible retrying + stream.seek(0) + raise exn + else: + raise retry.PermanentException( + "Skip retrying because we caught exception:" + + ''.join(traceback.format_exception_only(exn.__class__, exn)) + + ', but the stream is not seekable.') + else: + raise retry.PermanentException( + "Skip retrying because type " + str(type(stream_or_path)) + + "stream_or_path is unsupported.") + @retry.no_retries # Using no_retries marks this as an integration point. def create_job(self, job): """Creates job description. May stage and/or submit for remote execution.""" @@ -703,7 +737,7 @@ def create_job(self, job): job.options.view_as(GoogleCloudOptions).template_location) if job.options.view_as(DebugOptions).lookup_experiment('upload_graph'): - self.stage_file( + self.stage_file_with_retry( job.options.view_as(GoogleCloudOptions).staging_location, "dataflow_graph.json", io.BytesIO(job.json().encode('utf-8'))) @@ -718,7 +752,7 @@ def create_job(self, job): if job_location: gcs_or_local_path = os.path.dirname(job_location) file_name = os.path.basename(job_location) - self.stage_file( + self.stage_file_with_retry( gcs_or_local_path, file_name, io.BytesIO(job.json().encode('utf-8'))) if not template_location: @@ -790,7 +824,7 @@ def create_job_description(self, job): resources = self._stage_resources(job.proto_pipeline, job.options) # Stage proto pipeline. 
- self.stage_file( + self.stage_file_with_retry( job.google_cloud_options.staging_location, shared_names.STAGED_PIPELINE_FILENAME, io.BytesIO(job.proto_pipeline.SerializeToString())) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 6587e619a500..d055065cb9d9 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -19,11 +19,13 @@ # pytype: skip-file +import io import itertools import json import logging import os import sys +import time import unittest import mock @@ -42,6 +44,7 @@ from apache_beam.transforms import DoFn from apache_beam.transforms import ParDo from apache_beam.transforms.environments import DockerEnvironment +from apache_beam.utils import retry # Protect against environments where apitools library is not available. # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports @@ -1064,7 +1067,11 @@ def test_graph_is_uploaded(self): side_effect=None): client.create_job(job) client.stage_file.assert_called_once_with( - mock.ANY, "dataflow_graph.json", mock.ANY) + mock.ANY, + "dataflow_graph.json", + mock.ANY, + 'application/octet-stream', + None) client.create_job_description.assert_called_once() def test_create_job_returns_existing_job(self): @@ -1174,8 +1181,18 @@ def test_template_file_generation_with_upload_graph(self): client.create_job(job) client.stage_file.assert_has_calls([ - mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY), - mock.call(mock.ANY, 'template', mock.ANY) + mock.call( + mock.ANY, + 'dataflow_graph.json', + mock.ANY, + 'application/octet-stream', + None), + mock.call( + mock.ANY, + 'template', + mock.ANY, + 'application/octet-stream', + None) ]) client.create_job_description.assert_called_once() # template is generated, but job should not be submitted to the @@ -1653,6 +1670,93 @@ def 
exists_return_value(*args): })) self.assertEqual(pipeline, pipeline_expected) + def test_stage_file_with_retry(self): + def effect(self, *args, **kwargs): + nonlocal count + count += 1 + # Fail the first two calls and succeed afterward + if count <= 2: + raise Exception("This exception is raised for testing purpose.") + + class Unseekable(io.IOBase): + def seekable(self): + return False + + pipeline_options = PipelineOptions([ + '--project', + 'test_project', + '--job_name', + 'test_job_name', + '--temp_location', + 'gs://test-location/temp', + ]) + pipeline_options.view_as(GoogleCloudOptions).no_auth = True + client = apiclient.DataflowApplicationClient(pipeline_options) + + with mock.patch.object(client, 'stage_file') as mock_stage_file: + mock_stage_file.side_effect = effect + + with mock.patch.object(time, 'sleep') as mock_sleep: + with mock.patch("builtins.open", + mock.mock_open(read_data="data")) as mock_file_open: + count = 0 + # calling with a file name + client.stage_file_with_retry( + "/to", "new_name", "/from/old_name", total_size=4) + self.assertEqual(mock_stage_file.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2) + self.assertEqual(mock_file_open.call_count, 3) + + count = 0 + mock_stage_file.reset_mock() + mock_sleep.reset_mock() + mock_file_open.reset_mock() + + # calling with a seekable stream + client.stage_file_with_retry( + "/to", "new_name", io.BytesIO(b'test'), total_size=4) + self.assertEqual(mock_stage_file.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2) + # no open() is called if a stream is provided + mock_file_open.assert_not_called() + + count = 0 + mock_sleep.reset_mock() + mock_file_open.reset_mock() + mock_stage_file.reset_mock() + + # calling with an unseekable stream + self.assertRaises( + retry.PermanentException, + client.stage_file_with_retry, + "/to", + "new_name", + Unseekable(), + total_size=4) + # Unseekable streams are staged once. If staging fails, no retries are + # attempted. 
+ self.assertEqual(mock_stage_file.call_count, 1) + mock_sleep.assert_not_called() + mock_file_open.assert_not_called() + + count = 0 + mock_sleep.reset_mock() + mock_file_open.reset_mock() + mock_stage_file.reset_mock() + + # calling with something else + self.assertRaises( + retry.PermanentException, + client.stage_file_with_retry, + "/to", + "new_name", + object(), + total_size=4) + # No staging will be called for wrong arg type + mock_stage_file.assert_not_called() + mock_sleep.assert_not_called() + mock_file_open.assert_not_called() + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py index ac575e82717e..65bedf39fb2c 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/names.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py @@ -34,6 +34,6 @@ # Unreleased sdks use container image tag specified below. # Update this tag whenever there is a change that # requires changes to SDK harness container or SDK harness launcher. 
-BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20241118' +BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20250102' DATAFLOW_CONTAINER_IMAGE_REPOSITORY = 'gcr.io/cloud-dataflow/v1beta3' diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py index d20849d769af..5beb19d4610a 100644 --- a/sdks/python/apache_beam/runners/direct/direct_metrics.py +++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py @@ -27,6 +27,7 @@ from typing import Any from typing import SupportsInt +from apache_beam.metrics.cells import BoundedTrieData from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import StringSetData @@ -102,6 +103,8 @@ def __init__(self): lambda: DirectMetric(GenericAggregator(GaugeData))) self._string_sets = defaultdict( lambda: DirectMetric(GenericAggregator(StringSetData))) + self._bounded_tries = defaultdict( + lambda: DirectMetric(GenericAggregator(BoundedTrieData))) def _apply_operation(self, bundle, updates, op): for k, v in updates.counters.items(): @@ -116,6 +119,9 @@ def _apply_operation(self, bundle, updates, op): for k, v in updates.string_sets.items(): op(self._string_sets[k], bundle, v) + for k, v in updates.bounded_tries.items(): + op(self._bounded_tries[k], bundle, v) + def commit_logical(self, bundle, updates): op = lambda obj, bundle, update: obj.commit_logical(bundle, update) self._apply_operation(bundle, updates, op) @@ -157,12 +163,20 @@ def query(self, filter=None): v.extract_latest_attempted()) for k, v in self._string_sets.items() if self.matches(filter, k) ] + bounded_tries = [ + MetricResult( + MetricKey(k.step, k.metric), + v.extract_committed(), + v.extract_latest_attempted()) for k, + v in self._bounded_tries.items() if self.matches(filter, k) + ] return { self.COUNTERS: counters, self.DISTRIBUTIONS: distributions, self.GAUGES: gauges, - self.STRINGSETS: string_sets + self.STRINGSETS: 
string_sets, + self.BOUNDED_TRIES: bounded_tries, } diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py index d8f1ea097b88..1af5f1bc7bea 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py @@ -78,6 +78,8 @@ def process(self, element): distro.update(element) str_set = Metrics.string_set(self.__class__, 'element_str_set') str_set.add(str(element % 4)) + Metrics.bounded_trie(self.__class__, 'element_bounded_trie').add( + ("a", "b", str(element % 4))) return [element] p = Pipeline(DirectRunner()) @@ -124,6 +126,14 @@ def process(self, element): hc.assert_that(len(str_set_result.committed), hc.equal_to(4)) hc.assert_that(len(str_set_result.attempted), hc.equal_to(4)) + bounded_trie_results = metrics['bounded_tries'][0] + hc.assert_that( + bounded_trie_results.key, + hc.equal_to( + MetricKey('Do', MetricName(namespace, 'element_bounded_trie')))) + hc.assert_that(bounded_trie_results.committed.size(), hc.equal_to(4)) + hc.assert_that(bounded_trie_results.attempted.size(), hc.equal_to(4)) + def test_create_runner(self): self.assertTrue(isinstance(create_runner('DirectRunner'), DirectRunner)) self.assertTrue( diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py index 4dc2446fdd9d..30f1a4c06025 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py @@ -303,7 +303,7 @@ def test_flattened_side_input(self): super().test_flattened_side_input(with_transcoding=False) def test_metrics(self): - super().test_metrics(check_gauge=False) + super().test_metrics(check_gauge=False, check_bounded_trie=False) def test_sdf_with_watermark_tracking(self): raise unittest.SkipTest("BEAM-2939") diff --git 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 1ed21942d28f..95bcb7567918 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -1536,16 +1536,18 @@ def __init__(self, step_monitoring_infos, user_metrics_only=True): self._distributions = {} self._gauges = {} self._string_sets = {} + self._bounded_tries = {} self._user_metrics_only = user_metrics_only self._monitoring_infos = step_monitoring_infos for smi in step_monitoring_infos.values(): - counters, distributions, gauges, string_sets = \ - portable_metrics.from_monitoring_infos(smi, user_metrics_only) + counters, distributions, gauges, string_sets, bounded_tries = ( + portable_metrics.from_monitoring_infos(smi, user_metrics_only)) self._counters.update(counters) self._distributions.update(distributions) self._gauges.update(gauges) self._string_sets.update(string_sets) + self._bounded_tries.update(bounded_tries) def query(self, filter=None): counters = [ @@ -1564,12 +1566,17 @@ def query(self, filter=None): MetricResult(k, v, v) for k, v in self._string_sets.items() if self.matches(filter, k) ] + bounded_tries = [ + MetricResult(k, v, v) for k, + v in self._bounded_tries.items() if self.matches(filter, k) + ] return { self.COUNTERS: counters, self.DISTRIBUTIONS: distributions, self.GAUGES: gauges, - self.STRINGSETS: string_sets + self.STRINGSETS: string_sets, + self.BOUNDED_TRIES: bounded_tries, } def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 1309e7c74abc..3f036ab27f6e 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ 
b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -1209,13 +1209,14 @@ def expand(self, pcolls): pcoll_b = p | 'b' >> beam.Create(['b']) assert_that((pcoll_a, pcoll_b) | First(), equal_to(['a'])) - def test_metrics(self, check_gauge=True): + def test_metrics(self, check_gauge=True, check_bounded_trie=False): p = self.create_pipeline() counter = beam.metrics.Metrics.counter('ns', 'counter') distribution = beam.metrics.Metrics.distribution('ns', 'distribution') gauge = beam.metrics.Metrics.gauge('ns', 'gauge') string_set = beam.metrics.Metrics.string_set('ns', 'string_set') + bounded_trie = beam.metrics.Metrics.bounded_trie('ns', 'bounded_trie') elements = ['a', 'zzz'] pcoll = p | beam.Create(elements) @@ -1225,6 +1226,7 @@ def test_metrics(self, check_gauge=True): pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x))) pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(3)) pcoll | 'string_set' >> beam.FlatMap(lambda x: string_set.add(x)) + pcoll | 'bounded_trie' >> beam.FlatMap(lambda x: bounded_trie.add(tuple(x))) res = p.run() res.wait_until_finish() @@ -1248,6 +1250,14 @@ def test_metrics(self, check_gauge=True): .with_name('string_set'))['string_sets'] self.assertEqual(str_set.committed, set(elements)) + if check_bounded_trie: + bounded_trie, = res.metrics().query(beam.metrics.MetricsFilter() + .with_name('bounded_trie'))['bounded_tries'] + self.assertEqual(bounded_trie.committed.size(), 2) + for element in elements: + self.assertTrue( + bounded_trie.committed.contains(tuple(element)), element) + def test_callbacks_with_exception(self): elements_list = ['1', '2'] diff --git a/sdks/python/apache_beam/runners/portability/portable_metrics.py b/sdks/python/apache_beam/runners/portability/portable_metrics.py index 5bc3e0539181..e92d33910415 100644 --- a/sdks/python/apache_beam/runners/portability/portable_metrics.py +++ b/sdks/python/apache_beam/runners/portability/portable_metrics.py @@ -42,6 +42,7 @@ def 
from_monitoring_infos(monitoring_info_list, user_metrics_only=False): distributions = {} gauges = {} string_sets = {} + bounded_tries = {} for mi in monitoring_info_list: if (user_metrics_only and not monitoring_infos.is_user_monitoring_info(mi)): @@ -62,8 +63,10 @@ def from_monitoring_infos(monitoring_info_list, user_metrics_only=False): gauges[key] = metric_result elif monitoring_infos.is_string_set(mi): string_sets[key] = metric_result + elif monitoring_infos.is_bounded_trie(mi): + bounded_tries[key] = metric_result - return counters, distributions, gauges, string_sets + return counters, distributions, gauges, string_sets, bounded_tries def _create_metric_key(monitoring_info): diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index ba48bbec6d3a..fe9dcfa62b29 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -437,7 +437,7 @@ def _combine(committed, attempted, filter): ] def query(self, filter=None): - counters, distributions, gauges, stringsets = [ + counters, distributions, gauges, stringsets, bounded_tries = [ self._combine(x, y, filter) for x, y in zip(self.committed, self.attempted) ] @@ -446,7 +446,8 @@ def query(self, filter=None): self.COUNTERS: counters, self.DISTRIBUTIONS: distributions, self.GAUGES: gauges, - self.STRINGSETS: stringsets + self.STRINGSETS: stringsets, + self.BOUNDED_TRIES: bounded_tries, } diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py b/sdks/python/apache_beam/runners/portability/prism_runner_test.py index bc72d551f966..337ac9919487 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py @@ -231,6 +231,9 @@ def test_pack_combiners(self): "Requires Prism to support coder:" + " 'beam:coder:tuple:v1'. 
https://github.com/apache/beam/issues/32636") + def test_metrics(self): + super().test_metrics(check_bounded_trie=False) + # Inherits all other tests. diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index b091220a06b5..3cb1a26b77f1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -201,7 +201,9 @@ def __init__( self._data_channel_factory = data_plane.GrpcClientDataChannelFactory( credentials, self._worker_id, data_buffer_time_limit_ms) self._state_handler_factory = GrpcStateHandlerFactory( - self._state_cache, credentials) + state_cache=self._state_cache, + credentials=credentials, + worker_id=self._worker_id) self._profiler_factory = profiler_factory self.data_sampler = data_sampler self.runner_capabilities = runner_capabilities @@ -893,13 +895,14 @@ class GrpcStateHandlerFactory(StateHandlerFactory): Caches the created channels by ``state descriptor url``. 
""" - def __init__(self, state_cache, credentials=None): - # type: (StateCache, Optional[grpc.ChannelCredentials]) -> None + def __init__(self, state_cache, credentials=None, worker_id=None): + # type: (StateCache, Optional[grpc.ChannelCredentials], Optional[str]) -> None self._state_handler_cache = {} # type: Dict[str, CachingStateHandler] self._lock = threading.Lock() self._throwing_state_handler = ThrowingStateHandler() self._credentials = credentials self._state_cache = state_cache + self._worker_id = worker_id def create_state_handler(self, api_service_descriptor): # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler @@ -926,7 +929,7 @@ def create_state_handler(self, api_service_descriptor): _LOGGER.info('State channel established.') # Add workerId to the grpc channel grpc_channel = grpc.intercept_channel( - grpc_channel, WorkerIdInterceptor()) + grpc_channel, WorkerIdInterceptor(self._worker_id)) self._state_handler_cache[url] = GlobalCachingStateHandler( self._state_cache, GrpcStateHandler( diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py index 2271b4495d79..ecd4dc4e02c0 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status.py +++ b/sdks/python/apache_beam/runners/worker/worker_status.py @@ -151,6 +151,7 @@ def __init__( bundle_process_cache=None, state_cache=None, enable_heap_dump=False, + worker_id=None, log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS): """Initialize FnApiWorkerStatusHandler. 
@@ -164,7 +165,8 @@ def __init__( self._state_cache = state_cache ch = GRPCChannelFactory.insecure_channel(status_address) grpc.channel_ready_future(ch).result(timeout=60) - self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor()) + self._status_channel = grpc.intercept_channel( + ch, WorkerIdInterceptor(worker_id)) self._status_stub = beam_fn_api_pb2_grpc.BeamFnWorkerStatusStub( self._status_channel) self._responses = queue.Queue() diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py new file mode 100644 index 000000000000..223b973e5fbe --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import logging + +from apache_beam.examples.inference import tensorflow_mnist_classification +from apache_beam.testing.load_tests.dataflow_cost_benchmark import DataflowCostBenchmark + + +class TensorflowMNISTClassificationCostBenchmark(DataflowCostBenchmark): + def __init__(self): + super().__init__() + + def test(self): + extra_opts = {} + extra_opts['input'] = self.pipeline.get_option('input_file') + extra_opts['output'] = self.pipeline.get_option('output_file') + extra_opts['model_path'] = self.pipeline.get_option('model') + self.result = tensorflow_mnist_classification.run( + self.pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + TensorflowMNISTClassificationCostBenchmark().run() diff --git a/sdks/python/apache_beam/testing/benchmarks/wordcount/__init__.py b/sdks/python/apache_beam/testing/benchmarks/wordcount/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/wordcount/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py b/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py new file mode 100644 index 000000000000..513ede47e80a --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import logging + +from apache_beam.examples import wordcount +from apache_beam.testing.load_tests.dataflow_cost_benchmark import DataflowCostBenchmark + + +class WordcountCostBenchmark(DataflowCostBenchmark): + def __init__(self): + super().__init__() + + def test(self): + extra_opts = {} + extra_opts['output'] = self.pipeline.get_option('output_file') + self.result = wordcount.run( + self.pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + WordcountCostBenchmark().run() diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py new file mode 100644 index 000000000000..96a1cd31e298 --- /dev/null +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pytype: skip-file + +import logging +import time +from typing import Any +from typing import Optional + +import apache_beam.testing.load_tests.dataflow_cost_consts as costs +from apache_beam.metrics.execution import MetricResult +from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult +from apache_beam.runners.runner import PipelineState +from apache_beam.testing.load_tests.load_test import LoadTest + + +class DataflowCostBenchmark(LoadTest): + """Base class for Dataflow performance tests which export metrics to + external databases: BigQuery or/and InfluxDB. Calculates the expected cost + for running the job on Dataflow in region us-central1. + + Refer to :class:`~apache_beam.testing.load_tests.LoadTestOptions` for more + information on the required pipeline options. + + If using InfluxDB with Basic HTTP authentication enabled, provide the + following environment options: `INFLUXDB_USER` and `INFLUXDB_USER_PASSWORD`. + + If the hardware configuration for the job includes use of a GPU, please + specify the version in use with the Accelerator enumeration. This is used to + calculate the cost of the job later, as different accelerators have different + billing rates per hour of use. 
+ """ + def __init__( + self, + metrics_namespace: Optional[str] = None, + is_streaming: bool = False, + gpu: Optional[costs.Accelerator] = None): + self.is_streaming = is_streaming + self.gpu = gpu + super().__init__(metrics_namespace=metrics_namespace) + + def run(self): + try: + self.test() + if not hasattr(self, 'result'): + self.result = self.pipeline.run() + # Defaults to waiting forever unless timeout has been set + state = self.result.wait_until_finish(duration=self.timeout_ms) + assert state != PipelineState.FAILED + logging.info( + 'Pipeline complete, sleeping for 4 minutes to allow resource ' + 'metrics to populate.') + time.sleep(240) + self.extra_metrics = self._retrieve_cost_metrics(self.result) + self._metrics_monitor.publish_metrics(self.result, self.extra_metrics) + finally: + self.cleanup() + + def _retrieve_cost_metrics(self, + result: DataflowPipelineResult) -> dict[str, Any]: + job_id = result.job_id() + metrics = result.metrics().all_metrics(job_id) + metrics_dict = self._process_metrics_list(metrics) + logging.info(metrics_dict) + cost = 0.0 + if (self.is_streaming): + cost += metrics_dict.get( + "TotalVcpuTime", 0.0) / 3600 * costs.VCPU_PER_HR_STREAMING + cost += ( + metrics_dict.get("TotalMemoryUsage", 0.0) / + 1000) / 3600 * costs.MEM_PER_GB_HR_STREAMING + cost += metrics_dict.get( + "TotalStreamingDataProcessed", 0.0) * costs.SHUFFLE_PER_GB_STREAMING + else: + cost += metrics_dict.get( + "TotalVcpuTime", 0.0) / 3600 * costs.VCPU_PER_HR_BATCH + cost += ( + metrics_dict.get("TotalMemoryUsage", 0.0) / + 1000) / 3600 * costs.MEM_PER_GB_HR_BATCH + cost += metrics_dict.get( + "TotalStreamingDataProcessed", 0.0) * costs.SHUFFLE_PER_GB_BATCH + if (self.gpu): + rate = costs.ACCELERATOR_TO_COST[self.gpu] + cost += metrics_dict.get("TotalGpuTime", 0.0) / 3600 * rate + cost += metrics_dict.get("TotalPdUsage", 0.0) / 3600 * costs.PD_PER_GB_HR + cost += metrics_dict.get( + "TotalSsdUsage", 0.0) / 3600 * costs.PD_SSD_PER_GB_HR + 
metrics_dict["EstimatedCost"] = cost + return metrics_dict + + def _process_metrics_list(self, + metrics: list[MetricResult]) -> dict[str, Any]: + system_metrics = {} + for entry in metrics: + metric_key = entry.key + metric = metric_key.metric + if metric_key.step == '' and metric.namespace == 'dataflow/v1b3': + if entry.committed is None: + entry.committed = 0.0 + system_metrics[metric.name] = entry.committed + return system_metrics diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_consts.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_consts.py new file mode 100644 index 000000000000..f291991b48bb --- /dev/null +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_consts.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These values are Dataflow costs for running jobs in us-central1. 
+# The cost values are found at https://cloud.google.com/dataflow/pricing + +from enum import Enum + +VCPU_PER_HR_BATCH = 0.056 +VCPU_PER_HR_STREAMING = 0.069 +MEM_PER_GB_HR_BATCH = 0.003557 +MEM_PER_GB_HR_STREAMING = 0.0035557 +PD_PER_GB_HR = 0.000054 +PD_SSD_PER_GB_HR = 0.000298 +SHUFFLE_PER_GB_BATCH = 0.011 +SHUFFLE_PER_GB_STREAMING = 0.018 + +# GPU Resource Pricing +P100_PER_GPU_PER_HOUR = 1.752 +V100_PER_GPU_PER_HOUR = 2.976 +T4_PER_GPU_PER_HOUR = 0.42 +P4_PER_GPU_PER_HOUR = 0.72 +L4_PER_GPU_PER_HOUR = 0.672 +A100_40GB_PER_GPU_PER_HOUR = 3.72 +A100_80GB_PER_GPU_PER_HOUR = 4.7137 + + +class Accelerator(Enum): + P100 = 1 + V100 = 2 + T4 = 3 + P4 = 4 + L4 = 5 + A100_40GB = 6 + A100_80GB = 7 + + +ACCELERATOR_TO_COST: dict[Accelerator, float] = { + Accelerator.P100: P100_PER_GPU_PER_HOUR, + Accelerator.V100: V100_PER_GPU_PER_HOUR, + Accelerator.T4: T4_PER_GPU_PER_HOUR, + Accelerator.P4: P4_PER_GPU_PER_HOUR, + Accelerator.L4: L4_PER_GPU_PER_HOUR, + Accelerator.A100_40GB: A100_40GB_PER_GPU_PER_HOUR, + Accelerator.A100_80GB: A100_80GB_PER_GPU_PER_HOUR, +} diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 9c798d3ce6dc..b420d1d66d09 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2117,15 +2117,13 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name r""":func:`MapTuple` is like :func:`Map` but expects tuple inputs and flattens them into multiple input arguments. - beam.MapTuple(lambda a, b, ...: ...) - In other words - beam.MapTuple(fn) + "SwapKV" >> beam.Map(lambda kv: (kv[1], kv[0])) is equivalent to - beam.Map(lambda element, ...: fn(\*element, ...)) + "SwapKV" >> beam.MapTuple(lambda k, v: (v, k)) This can be useful when processing a PCollection of tuples (e.g. key-value pairs). 
@@ -2191,19 +2189,13 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name r""":func:`FlatMapTuple` is like :func:`FlatMap` but expects tuple inputs and flattens them into multiple input arguments. - beam.FlatMapTuple(lambda a, b, ...: ...) - - is equivalent to Python 2 - - beam.FlatMap(lambda (a, b, ...), ...: ...) - In other words - beam.FlatMapTuple(fn) + beam.FlatMap(lambda start_end: range(start_end[0], start_end[1])) is equivalent to - beam.FlatMap(lambda element, ...: fn(\*element, ...)) + beam.FlatMapTuple(lambda start, end: range(start, end)) This can be useful when processing a PCollection of tuples (e.g. key-value pairs). @@ -2238,7 +2230,7 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name if defaults or args or kwargs: wrapper = lambda x, *args, **kwargs: fn(*(tuple(x) + args), **kwargs) else: - wrapper = lambda x: fn(*x) + wrapper = lambda x: fn(*tuple(x)) # Proxy the type-hint information from the original function to this new # wrapped function. diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index e44f7482dc61..9ca5886f4cc2 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -239,7 +239,8 @@ def dict_to_row(schema_proto, py_value): extra = set(py_value.keys()) - set(row_type._fields) if extra: raise ValueError( - f"Unknown fields: {extra}. Valid fields: {row_type._fields}") + f"Transform '{self.identifier()}' was configured with unknown " + f"fields: {extra}. 
Valid fields: {set(row_type._fields)}") return row_type( *[ dict_to_row_recursive( @@ -962,14 +963,14 @@ def __init__( self, path_to_jar, extra_args=None, classpath=None, append_args=None): if extra_args and append_args: raise ValueError('Only one of extra_args or append_args may be provided') - self._path_to_jar = path_to_jar + self.path_to_jar = path_to_jar self._extra_args = extra_args self._classpath = classpath or [] self._service_count = 0 self._append_args = append_args or [] def is_existing_service(self): - return subprocess_server.is_service_endpoint(self._path_to_jar) + return subprocess_server.is_service_endpoint(self.path_to_jar) @staticmethod def _expand_jars(jar): @@ -997,7 +998,7 @@ def _expand_jars(jar): def _default_args(self): """Default arguments to be used by `JavaJarExpansionService`.""" - to_stage = ','.join([self._path_to_jar] + sum(( + to_stage = ','.join([self.path_to_jar] + sum(( JavaJarExpansionService._expand_jars(jar) for jar in self._classpath or []), [])) args = ['{{PORT}}', f'--filesToStage={to_stage}'] @@ -1009,8 +1010,8 @@ def _default_args(self): def __enter__(self): if self._service_count == 0: - self._path_to_jar = subprocess_server.JavaJarServer.local_jar( - self._path_to_jar) + self.path_to_jar = subprocess_server.JavaJarServer.local_jar( + self.path_to_jar) if self._extra_args is None: self._extra_args = self._default_args() + self._append_args # Consider memoizing these servers (with some timeout). 
@@ -1018,7 +1019,7 @@ def __enter__(self): 'Starting a JAR-based expansion service from JAR %s ' + ( 'and with classpath: %s' % self._classpath if self._classpath else ''), - self._path_to_jar) + self.path_to_jar) classpath_urls = [ subprocess_server.JavaJarServer.local_jar(path) for jar in self._classpath @@ -1026,7 +1027,7 @@ def __enter__(self): ] self._service_provider = subprocess_server.JavaJarServer( ExpansionAndArtifactRetrievalStub, - self._path_to_jar, + self.path_to_jar, self._extra_args, classpath=classpath_urls) self._service = self._service_provider.__enter__() diff --git a/sdks/python/apache_beam/transforms/external_transform_provider.py b/sdks/python/apache_beam/transforms/external_transform_provider.py index 117c7f7c9b93..b22cd4b24cb6 100644 --- a/sdks/python/apache_beam/transforms/external_transform_provider.py +++ b/sdks/python/apache_beam/transforms/external_transform_provider.py @@ -26,6 +26,7 @@ from apache_beam.transforms import PTransform from apache_beam.transforms.external import BeamJarExpansionService +from apache_beam.transforms.external import JavaJarExpansionService from apache_beam.transforms.external import SchemaAwareExternalTransform from apache_beam.transforms.external import SchemaTransformsConfig from apache_beam.typehints.schemas import named_tuple_to_schema @@ -133,37 +134,57 @@ class ExternalTransformProvider: (see the `urn_pattern` parameter). These classes are generated when :class:`ExternalTransformProvider` is - initialized. We need to give it one or more expansion service addresses that - are already up and running: - >>> provider = ExternalTransformProvider(["localhost:12345", - ... "localhost:12121"]) - We can also give it the gradle target of a standard Beam expansion service: - >>> provider = ExternalTransform(BeamJarExpansionService( - ... 
"sdks:java:io:google-cloud-platform:expansion-service:shadowJar")) - Let's take a look at the output of :func:`get_available()` to know the - available transforms in the expansion service(s) we provided: + initialized. You can give it an expansion service address that is already + up and running: + + >>> provider = ExternalTransformProvider("localhost:12345") + + Or you can give it the path to an expansion service Jar file: + + >>> provider = ExternalTransformProvider(JavaJarExpansionService( + "path/to/expansion-service.jar")) + + Or you can give it the gradle target of a standard Beam expansion service: + + >>> provider = ExternalTransformProvider(BeamJarExpansionService( + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")) + + Note that you can provide a list of these services: + + >>> provider = ExternalTransformProvider([ + "localhost:12345", + JavaJarExpansionService("path/to/expansion-service.jar"), + BeamJarExpansionService( + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")]) + + The output of :func:`get_available()` provides a list of available transforms + in the provided expansion service(s): + >>> provider.get_available() [('JdbcWrite', 'beam:schematransform:org.apache.beam:jdbc_write:v1'), ('BigtableRead', 'beam:schematransform:org.apache.beam:bigtable_read:v1'), ...] - Then retrieve a transform by :func:`get()`, :func:`get_urn()`, or by directly - accessing it as an attribute of :class:`ExternalTransformProvider`. - All of the following commands do the same thing: + You can retrieve a transform with :func:`get()`, :func:`get_urn()`, or by + directly accessing it as an attribute. The following lines all do the same + thing: + >>> provider.get('BigqueryStorageRead') >>> provider.get_urn( - ... 
'beam:schematransform:org.apache.beam:bigquery_storage_read:v1') + 'beam:schematransform:org.apache.beam:bigquery_storage_read:v1') >>> provider.BigqueryStorageRead - You can inspect the transform's documentation to know more about it. This - returns some documentation only IF the underlying SchemaTransform - implementation provides any. + You can inspect the transform's documentation for more details. The following + returns the documentation provided by the underlying SchemaTransform. If no + such documentation is provided, this will be empty. + >>> import inspect >>> inspect.getdoc(provider.BigqueryStorageRead) Similarly, you can inspect the transform's signature to know more about its parameters, including their names, types, and any documentation that the underlying SchemaTransform may provide: + >>> inspect.signature(provider.BigqueryStorageRead) (query: 'typing.Union[str, NoneType]: The SQL query to be executed to...', row_restriction: 'typing.Union[str, NoneType]: Read only rows that match...', @@ -178,8 +199,6 @@ class ExternalTransformProvider: query=query, row_restriction=restriction) | 'Some processing' >> beam.Map(...)) - - Experimental; no backwards compatibility guarantees. """ def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN): f"""Initialize an ExternalTransformProvider @@ -188,6 +207,7 @@ def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN): A list of expansion services to discover transforms from. Supported forms: * a string representing the expansion service address + * a :attr:`JavaJarExpansionService` pointing to the path of a Java Jar * a :attr:`BeamJarExpansionService` pointing to a gradle target :param urn_pattern: The regular expression used to match valid transforms. 
In addition to @@ -213,11 +233,14 @@ def _create_wrappers(self): target = service if isinstance(service, BeamJarExpansionService): target = service.gradle_target + if isinstance(service, JavaJarExpansionService): + target = service.path_to_jar try: schematransform_configs = SchemaAwareExternalTransform.discover(service) except Exception as e: logging.exception( - "Encountered an error while discovering expansion service %s:\n%s", + "Encountered an error while discovering " + "expansion service at '%s':\n%s", target, e) continue @@ -249,7 +272,7 @@ def _create_wrappers(self): if skipped_urns: logging.info( - "Skipped URN(s) in %s that don't follow the pattern \"%s\": %s", + "Skipped URN(s) in '%s' that don't follow the pattern \"%s\": %s", target, self._urn_pattern, skipped_urns) @@ -262,7 +285,7 @@ def get_available(self) -> List[Tuple[str, str]]: return list(self._name_to_urn.items()) def get_all(self) -> Dict[str, ExternalTransform]: - """Get all ExternalTransform""" + """Get all ExternalTransforms""" return self._transforms def get(self, name) -> ExternalTransform: diff --git a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py index 0dfa2aa19c51..20cb52335c76 100644 --- a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py +++ b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py @@ -16,15 +16,13 @@ # import os -import secrets -import shutil -import tempfile import time import unittest import pytest import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -35,17 +33,14 @@ "EXPANSION_JARS environment var is not provided, " "indicating that jars have not been built") class ManagedIcebergIT(unittest.TestCase): - def setUp(self): - self._tempdir = tempfile.mkdtemp() - if not os.path.exists(self._tempdir): - os.mkdir(self._tempdir) - 
test_warehouse_name = 'test_warehouse_%d_%s' % ( - int(time.time()), secrets.token_hex(3)) - self.warehouse_path = os.path.join(self._tempdir, test_warehouse_name) - os.mkdir(self.warehouse_path) + WAREHOUSE = "gs://temp-storage-for-end-to-end-tests/xlang-python-using-java" - def tearDown(self): - shutil.rmtree(self._tempdir, ignore_errors=False) + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.args = self.test_pipeline.get_full_options_as_args() + self.args.extend([ + '--experiments=enable_managed_transforms', + ]) def _create_row(self, num: int): return beam.Row( @@ -57,24 +52,24 @@ def _create_row(self, num: int): def test_write_read_pipeline(self): iceberg_config = { - "table": "test.write_read", + "table": "test_iceberg_write_read.test_" + str(int(time.time())), "catalog_name": "default", "catalog_properties": { "type": "hadoop", - "warehouse": f"file://{self.warehouse_path}", + "warehouse": self.WAREHOUSE, } } rows = [self._create_row(i) for i in range(100)] expected_dicts = [row.as_dict() for row in rows] - with beam.Pipeline() as write_pipeline: + with beam.Pipeline(argv=self.args) as write_pipeline: _ = ( write_pipeline | beam.Create(rows) | beam.managed.Write(beam.managed.ICEBERG, config=iceberg_config)) - with beam.Pipeline() as read_pipeline: + with beam.Pipeline(argv=self.args) as read_pipeline: output_dicts = ( read_pipeline | beam.managed.Read(beam.managed.ICEBERG, config=iceberg_config) diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 44318fa44a8c..820f78fa9ef5 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -21,6 +21,7 @@ import typing import unittest +from typing import Tuple import apache_beam as beam from apache_beam import pvalue @@ -999,5 +1000,22 @@ def filter_fn(element: int) -> bool: self.assertEqual(th.output_types, ((int, ), {})) 
+class TestFlatMapTuple(unittest.TestCase): + def test_flatmaptuple(self): + # Regression test. See + # https://github.com/apache/beam/issues/33014 + + def identity(x: Tuple[str, int]) -> Tuple[str, int]: + return x + + with beam.Pipeline() as p: + # Just checking that this doesn't raise an exception. + ( + p + | "Generate input" >> beam.Create([('P1', [2])]) + | "Flat" >> beam.FlatMapTuple(lambda k, vs: [(k, v) for v in vs]) + | "Identity" >> beam.Map(identity)) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/version.py b/sdks/python/apache_beam/version.py index 9974bb68bccf..39185712b141 100644 --- a/sdks/python/apache_beam/version.py +++ b/sdks/python/apache_beam/version.py @@ -17,4 +17,4 @@ """Apache Beam SDK version information and utilities.""" -__version__ = '2.62.0.dev' +__version__ = '2.63.0.dev' diff --git a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py index 3b497ed1efab..109e98410852 100644 --- a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py +++ b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py @@ -21,9 +21,11 @@ import os import random import unittest +from typing import Any from typing import Callable from typing import Dict from typing import List +from typing import Optional from typing import Union from unittest import mock @@ -34,11 +36,63 @@ from apache_beam.examples.snippets.util import assert_matches_stdout from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.yaml import yaml_provider from apache_beam.yaml import yaml_transform from apache_beam.yaml.readme_test import TestEnvironment from apache_beam.yaml.readme_test import replace_recursive +# Used to simulate Enrichment transform during tests +# The GitHub action that invokes these tests does not +# have gcp dependencies installed which is a prerequisite +# 
to apache_beam.transforms.enrichment.Enrichment as a top-level +# import. +@beam.ptransform.ptransform_fn +def test_enrichment( + pcoll, + enrichment_handler: str, + handler_config: Dict[str, Any], + timeout: Optional[float] = 30): + if enrichment_handler == 'BigTable': + row_key = handler_config['row_key'] + bt_data = INPUT_TABLES[( + 'BigTable', handler_config['instance_id'], handler_config['table_id'])] + products = {str(data[row_key]): data for data in bt_data} + + def _fn(row): + left = row._asdict() + right = products[str(left[row_key])] + left['product'] = left.get('product', None) or right + return beam.Row(**left) + elif enrichment_handler == 'BigQuery': + row_key = handler_config['fields'] + dataset, table = handler_config['table_name'].split('.')[-2:] + bq_data = INPUT_TABLES[('BigQuery', str(dataset), str(table))] + bq_data = { + tuple(str(data[key]) for key in row_key): data + for data in bq_data + } + + def _fn(row): + left = row._asdict() + right = bq_data[tuple(str(left[k]) for k in row_key)] + row = { + key: left.get(key, None) or right[key] + for key in {*left.keys(), *right.keys()} + } + return beam.Row(**row) + + else: + raise ValueError(f'{enrichment_handler} is not a valid enrichment_handler.') + + return pcoll | beam.Map(_fn) + + +TEST_PROVIDERS = { + 'TestEnrichment': test_enrichment, +} + + def check_output(expected: List[str]): def _check_inner(actual: List[PCollection[str]]): formatted_actual = actual | beam.Flatten() | beam.Map( @@ -59,7 +113,31 @@ def products_csv(): ]) -def spanner_data(): +def spanner_orders_data(): + return [{ + 'order_id': 1, + 'customer_id': 1001, + 'product_id': 2001, + 'order_date': '24-03-24', + 'order_amount': 150, + }, + { + 'order_id': 2, + 'customer_id': 1002, + 'product_id': 2002, + 'order_date': '19-04-24', + 'order_amount': 90, + }, + { + 'order_id': 3, + 'customer_id': 1003, + 'product_id': 2003, + 'order_date': '7-05-24', + 'order_amount': 110, + }] + + +def spanner_shipments_data(): return [{ 
'shipment_id': 'S1', 'customer_id': 'C1', @@ -110,6 +188,44 @@ def spanner_data(): }] +def bigtable_data(): + return [{ + 'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2' + }, { + 'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4' + }, { + 'product_id': '3', 'product_name': 'pixel 7', 'product_stock': '20' + }, { + 'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10' + }, { + 'product_id': '5', 'product_name': 'pixel 11', 'product_stock': '3' + }, { + 'product_id': '6', 'product_name': 'pixel 12', 'product_stock': '7' + }, { + 'product_id': '7', 'product_name': 'pixel 13', 'product_stock': '8' + }, { + 'product_id': '8', 'product_name': 'pixel 14', 'product_stock': '3' + }] + + +def bigquery_data(): + return [{ + 'customer_id': 1001, + 'customer_name': 'Alice', + 'customer_email': 'alice@gmail.com' + }, + { + 'customer_id': 1002, + 'customer_name': 'Bob', + 'customer_email': 'bob@gmail.com' + }, + { + 'customer_id': 1003, + 'customer_name': 'Claire', + 'customer_email': 'claire@gmail.com' + }] + + def create_test_method( pipeline_spec_file: str, custom_preprocessors: List[Callable[..., Union[Dict, List]]]): @@ -135,7 +251,11 @@ def test_yaml_example(self): pickle_library='cloudpickle', **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( 'options', {})))) as p: - actual = [yaml_transform.expand_pipeline(p, pipeline_spec)] + actual = [ + yaml_transform.expand_pipeline( + p, + pipeline_spec, [yaml_provider.InlineProvider(TEST_PROVIDERS)]) + ] if not actual[0]: actual = list(p.transforms_stack[0].parts[-1].outputs.values()) for transform in p.transforms_stack[0].parts[:-1]: @@ -213,7 +333,8 @@ def _wordcount_test_preprocessor( 'test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml', 'test_spanner_read_yaml', - 'test_spanner_write_yaml' + 'test_spanner_write_yaml', + 'test_enrich_spanner_with_bigquery_yaml' ]) def _io_write_test_preprocessor( test_spec: dict, expected: List[str], env: 
TestEnvironment): @@ -249,7 +370,8 @@ def _file_io_read_test_preprocessor( return test_spec -@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml']) +@YamlExamplesTestSuite.register_test_preprocessor( + ['test_spanner_read_yaml', 'test_enrich_spanner_with_bigquery_yaml']) def _spanner_io_read_test_preprocessor( test_spec: dict, expected: List[str], env: TestEnvironment): @@ -265,14 +387,42 @@ def _spanner_io_read_test_preprocessor( k: v for k, v in config.items() if k.startswith('__') } - transform['config']['elements'] = INPUT_TABLES[( - str(instance), str(database), str(table))] + elements = INPUT_TABLES[(str(instance), str(database), str(table))] + if config.get('query', None): + config['query'].replace('select ', + 'SELECT ').replace(' from ', ' FROM ') + columns = set( + ''.join(config['query'].split('SELECT ')[1:]).split( + ' FROM', maxsplit=1)[0].split(', ')) + if columns != {'*'}: + elements = [{ + column: element[column] + for column in element if column in columns + } for element in elements] + transform['config']['elements'] = elements + + return test_spec + + +@YamlExamplesTestSuite.register_test_preprocessor( + ['test_bigtable_enrichment_yaml', 'test_enrich_spanner_with_bigquery_yaml']) +def _enrichment_test_preprocessor( + test_spec: dict, expected: List[str], env: TestEnvironment): + if pipeline := test_spec.get('pipeline', None): + for transform in pipeline.get('transforms', []): + if transform.get('type', '').startswith('Enrichment'): + transform['type'] = 'TestEnrichment' return test_spec INPUT_FILES = {'products.csv': products_csv()} -INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()} +INPUT_TABLES = { + ('shipment-test', 'shipment', 'shipments'): spanner_shipments_data(), + ('orders-test', 'order-database', 'orders'): spanner_orders_data(), + ('BigTable', 'beam-test', 'bigtable-enrichment-test'): bigtable_data(), + ('BigQuery', 'ALL_TEST', 'customers'): bigquery_data() +} YAML_DOCS_DIR = 
os.path.join(os.path.dirname(__file__)) ExamplesTest = YamlExamplesTestSuite( @@ -290,6 +440,10 @@ def _spanner_io_read_test_preprocessor( 'IOExamplesTest', os.path.join(YAML_DOCS_DIR, '../transforms/io/*.yaml')).run() +MLTest = YamlExamplesTestSuite( + 'MLExamplesTest', os.path.join(YAML_DOCS_DIR, + '../transforms/ml/*.yaml')).run() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/yaml/examples/transforms/ml/bigtable_enrichment.yaml b/sdks/python/apache_beam/yaml/examples/transforms/ml/bigtable_enrichment.yaml new file mode 100644 index 000000000000..788b69de7857 --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/ml/bigtable_enrichment.yaml @@ -0,0 +1,55 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipeline: + type: chain + transforms: + + # Step 1: Creating a collection of elements that needs + # to be enriched. 
Here we are simulating sales data + - type: Create + config: + elements: + - sale_id: 1 + customer_id: 1 + product_id: 1 + quantity: 1 + + # Step 2: Enriching the data with Bigtable + # This specific bigtable stores product data in the below format + # product:product_id, product:product_name, product:product_stock + - type: Enrichment + config: + enrichment_handler: 'BigTable' + handler_config: + project_id: 'apache-beam-testing' + instance_id: 'beam-test' + table_id: 'bigtable-enrichment-test' + row_key: 'product_id' + timeout: 30 + + # Step 3: Logging for testing + # This is a simple way to view the enriched data + # We can also store it somewhere like a json file + - type: LogForTesting + +options: + yaml_experimental_features: Enrichment + +# Expected: +# Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/examples/transforms/ml/enrich_spanner_with_bigquery.yaml b/sdks/python/apache_beam/yaml/examples/transforms/ml/enrich_spanner_with_bigquery.yaml new file mode 100644 index 000000000000..e63b3105cc0c --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/ml/enrich_spanner_with_bigquery.yaml @@ -0,0 +1,102 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipeline: + transforms: + # Step 1: Read orders details from Spanner + - type: ReadFromSpanner + name: ReadOrders + config: + project_id: 'apache-beam-testing' + instance_id: 'orders-test' + database_id: 'order-database' + query: 'SELECT customer_id, product_id, order_date, order_amount FROM orders' + + # Step 2: Enrich order details with customers details from BigQuery + - type: Enrichment + name: Enriched + input: ReadOrders + config: + enrichment_handler: 'BigQuery' + handler_config: + project: "apache-beam-testing" + table_name: "apache-beam-testing.ALL_TEST.customers" + row_restriction_template: "customer_id = 1001 or customer_id = 1003" + fields: ["customer_id"] + + # Step 3: Map enriched values to Beam schema + # TODO: This should be removed when schema'd enrichment is available + - type: MapToFields + name: MapEnrichedValues + input: Enriched + config: + language: python + fields: + customer_id: + callable: 'lambda x: x.customer_id' + output_type: integer + customer_name: + callable: 'lambda x: x.customer_name' + output_type: string + customer_email: + callable: 'lambda x: x.customer_email' + output_type: string + product_id: + callable: 'lambda x: x.product_id' + output_type: integer + order_date: + callable: 'lambda x: x.order_date' + output_type: string + order_amount: + callable: 'lambda x: x.order_amount' + output_type: integer + + # Step 4: Filter orders with amount greater than 110 + - type: Filter + name: FilterHighValueOrders + input: MapEnrichedValues + config: + keep: "order_amount > 110" + language: "python" + + + # Step 6: Write 
processed order to another spanner table + # Note: Make sure to replace $VARS with your values. + - type: WriteToSpanner + name: WriteProcessedOrders + input: FilterHighValueOrders + config: + project_id: '$PROJECT' + instance_id: '$INSTANCE' + database_id: '$DATABASE' + table_id: '$TABLE' + error_handling: + output: my_error_output + + # Step 7: Handle write errors by writing to JSON + - type: WriteToJson + name: WriteErrorsToJson + input: WriteProcessedOrders.my_error_output + config: + path: 'errors.json' + +options: + yaml_experimental_features: Enrichment + +# Expected: +# Row(customer_id=1001, customer_name='Alice', customer_email='alice@gmail.com', product_id=2001, order_date='24-03-24', order_amount=150) diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 27e17029f387..693df6179a2d 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -194,12 +194,46 @@ def io_grouping_key(transform_name): SKIP = {} +def add_transform_links(transform, description, provider_list): + """ + Convert references of Providers to urls that link to their respective pages. + + For example, + "Some description talking about MyTransform." + would be converted to + "Some description talking about MyTransform" + + meanwhile:: + + type: MyTransform + config: + ... + + Would remain unchanged. + + Avoid self-linking within a Transform page. + """ + for p in provider_list: + # Match all instances of built-in transforms within the description + # excluding the transform whose description is currently being evaluated. + # Match the entire word boundary so that partial matches do not count. + # (i.e. 
OtherTransform should not match Transform) + description = re.sub( + rf"(?{p}', + description or '') + return description + + def transform_docs(transform_base, transforms, providers, extra_docs=''): return '\n'.join([ f'## {transform_base}', '', longest( - lambda t: longest(lambda p: p.description(t), providers[t]), + lambda t: longest( + lambda p: add_transform_links( + t, p.description(t), providers.keys()), + providers[t]), transforms).replace('::\n', '\n\n :::yaml\n'), '', extra_docs, @@ -250,7 +284,7 @@ def main(): if options.markdown_file or options.html_file: if '-' in transforms[0]: extra_docs = 'Supported languages: ' + ', '.join( - t.split('-')[-1] for t in sorted(transforms)) + t.split('-')[-1] for t in sorted(transforms)) + '.' else: extra_docs = '' markdown_out.write( diff --git a/sdks/python/apache_beam/yaml/main_test.py b/sdks/python/apache_beam/yaml/main_test.py index 1a3da6443b72..d5fbfedc0349 100644 --- a/sdks/python/apache_beam/yaml/main_test.py +++ b/sdks/python/apache_beam/yaml/main_test.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +import datetime import glob import logging import os @@ -100,6 +101,18 @@ def test_preparse_jinja_flags(self): 'pos_arg', ]) + def test_jinja_datetime(self): + with tempfile.TemporaryDirectory() as tmpdir: + out_path = os.path.join(tmpdir, 'out.txt') + main.run([ + '--yaml_pipeline', + TEST_PIPELINE.replace('PATH', out_path).replace( + 'ELEMENT', '"{{datetime.datetime.now().strftime("%Y-%m-%d")}}"'), + ]) + with open(glob.glob(out_path + '*')[0], 'rt') as fin: + self.assertEqual( + fin.read().strip(), datetime.datetime.now().strftime("%Y-%m-%d")) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 305e6877ad90..a21782bdc603 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -194,38 +194,43 @@ transforms: 'ReadFromJdbc': 'ReadFromJdbc' 'WriteToJdbc': 'WriteToJdbc' - 'ReadFromMySql': 'ReadFromJdbc' - 'WriteToMySql': 'WriteToJdbc' - 'ReadFromPostgres': 'ReadFromJdbc' - 'WriteToPostgres': 'WriteToJdbc' - 'ReadFromOracle': 'ReadFromJdbc' - 'WriteToOracle': 'WriteToJdbc' - 'ReadFromSqlServer': 'ReadFromJdbc' - 'WriteToSqlServer': 'WriteToJdbc' + 'ReadFromMySql': 'ReadFromMySql' + 'WriteToMySql': 'WriteToMySql' + 'ReadFromPostgres': 'ReadFromPostgres' + 'WriteToPostgres': 'WriteToPostgres' + 'ReadFromOracle': 'ReadFromOracle' + 'WriteToOracle': 'WriteToOracle' + 'ReadFromSqlServer': 'ReadFromSqlServer' + 'WriteToSqlServer': 'WriteToSqlServer' config: mappings: 'ReadFromJdbc': - driver_class_name: 'driver_class_name' - type: 'jdbc_type' url: 'jdbc_url' - username: 'username' - password: 'password' - table: 'location' - query: 'read_query' - driver_jars: 'driver_jars' - connection_properties: 'connection_properties' connection_init_sql: 'connection_init_sql' - 'WriteToJdbc': + connection_properties: 'connection_properties' + disable_auto_commit: 'disable_auto_commit' 
driver_class_name: 'driver_class_name' + driver_jars: 'driver_jars' + fetch_size: 'fetch_size' + output_parallelization: 'output_parallelization' + password: 'password' + query: 'read_query' + table: 'location' type: 'jdbc_type' - url: 'jdbc_url' username: 'username' + 'WriteToJdbc': + url: 'jdbc_url' + auto_sharding: 'autosharding' + connection_init_sql: 'connection_init_sql' + connection_properties: 'connection_properties' + driver_class_name: 'driver_class_name' + driver_jars: 'driver_jars' password: 'password' table: 'location' - driver_jars: 'driver_jars' - connection_properties: 'connection_properties' - connection_init_sql: 'connection_init_sql' batch_size: 'batch_size' + type: 'jdbc_type' + username: 'username' + query: 'write_statement' 'ReadFromMySql': 'ReadFromJdbc' 'WriteToMySql': 'WriteToJdbc' 'ReadFromPostgres': 'ReadFromJdbc' @@ -236,26 +241,56 @@ 'WriteToSqlServer': 'WriteToJdbc' defaults: 'ReadFromMySql': - jdbc_type: 'mysql' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'WriteToMySql': - jdbc_type: 'mysql' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'ReadFromPostgres': - jdbc_type: 'postgres' + connection_init_sql: '' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'WriteToPostgres': - jdbc_type: 'postgres' + connection_init_sql: '' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'ReadFromOracle': - jdbc_type: 'oracle' + connection_init_sql: '' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'WriteToOracle': - jdbc_type: 'oracle' + connection_init_sql: '' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'ReadFromSqlServer': - jdbc_type: 'mssql' + connection_init_sql: '' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' 'WriteToSqlServer': - jdbc_type: 'mssql' + connection_init_sql: '' + driver_class_name: '' + driver_jars: '' + jdbc_type: '' underlying_provider: type: beamJar transforms: 'ReadFromJdbc': 'beam:schematransform:org.apache.beam:jdbc_read:v1' + 'ReadFromMySql': 
'beam:schematransform:org.apache.beam:mysql_read:v1' + 'ReadFromPostgres': 'beam:schematransform:org.apache.beam:postgres_read:v1' + 'ReadFromOracle': 'beam:schematransform:org.apache.beam:oracle_read:v1' + 'ReadFromSqlServer': 'beam:schematransform:org.apache.beam:sql_server_read:v1' 'WriteToJdbc': 'beam:schematransform:org.apache.beam:jdbc_write:v1' + 'WriteToMySql': 'beam:schematransform:org.apache.beam:mysql_write:v1' + 'WriteToPostgres': 'beam:schematransform:org.apache.beam:postgres_write:v1' + 'WriteToOracle': 'beam:schematransform:org.apache.beam:oracle_write:v1' + 'WriteToSqlServer': 'beam:schematransform:org.apache.beam:sql_server_write:v1' config: gradle_target: 'sdks:java:extensions:schemaio-expansion-service:shadowJar' diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 242faaa9a77b..31eb5e1c6daa 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -56,6 +56,7 @@ config: {} transforms: MLTransform: 'apache_beam.yaml.yaml_ml.ml_transform' + RunInference: 'apache_beam.yaml.yaml_ml.run_inference' - type: renaming transforms: diff --git a/sdks/python/apache_beam/yaml/yaml_combine.py b/sdks/python/apache_beam/yaml/yaml_combine.py index a28bef52ea31..b7499f3b0c7a 100644 --- a/sdks/python/apache_beam/yaml/yaml_combine.py +++ b/sdks/python/apache_beam/yaml/yaml_combine.py @@ -94,6 +94,12 @@ class PyJsYamlCombine(beam.PTransform): See also the documentation on [YAML Aggregation](https://beam.apache.org/documentation/sdks/yaml-combine/). + + Args: + group_by: The field(s) to aggregate on. + combine: The aggregation function to use. + language: The language used to define (and execute) the + custom callables in `combine`. Defaults to generic. 
""" def __init__( self, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 9f92f59f42b6..7f7da7aca6a9 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -23,6 +23,7 @@ from typing import Callable from typing import Collection from typing import Dict +from typing import Iterable from typing import List from typing import Mapping from typing import Optional @@ -428,19 +429,19 @@ class _StripErrorMetadata(beam.PTransform): For example, in the following pipeline snippet:: - - name: MyMappingTransform - type: MapToFields - input: SomeInput - config: - language: python - fields: - ... - error_handling: - output: errors + - name: MyMappingTransform + type: MapToFields + input: SomeInput + config: + language: python + fields: + ... + error_handling: + output: errors - - name: RecoverOriginalElements - type: StripErrorMetadata - input: MyMappingTransform.errors + - name: RecoverOriginalElements + type: StripErrorMetadata + input: MyMappingTransform.errors the output of `RecoverOriginalElements` will contain exactly those elements from SomeInput that failed to processes (whereas `MyMappingTransform.errors` @@ -453,6 +454,9 @@ class _StripErrorMetadata(beam.PTransform): _ERROR_FIELD_NAMES = ('failed_row', 'element', 'record') + def __init__(self): + super().__init__(label=None) + def expand(self, pcoll): try: existing_fields = { @@ -616,6 +620,13 @@ def _PyJsFilter( See more complete documentation on [YAML Filtering](https://beam.apache.org/documentation/sdks/yaml-udf/#filtering). + + Args: + keep: An expression evaluating to true for those records that should be kept. + language: The language of the above expression. + Defaults to generic. + error_handling: Whether and where to output records that throw errors when + the above expressions are evaluated. 
""" # pylint: disable=line-too-long keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic') return pcoll | beam.Filter(keep_fn) @@ -661,14 +672,32 @@ def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'): @beam.ptransform.ptransform_fn @maybe_with_exception_handling_transform_fn -def _PyJsMapToFields(pcoll, language='generic', **mapping_args): +def _PyJsMapToFields( + pcoll, + fields: Mapping[str, Union[str, Mapping[str, str]]], + append: Optional[bool] = False, + drop: Optional[Iterable[str]] = None, + language: Optional[str] = None): """Creates records with new fields defined in terms of the input fields. See more complete documentation on [YAML Mapping Functions](https://beam.apache.org/documentation/sdks/yaml-udf/#mapping-functions). + + Args: + fields: The output fields to compute, each mapping to the expression or + callable that creates them. + append: Whether to append the created fields to the set of + fields already present, outputting a union of both the new fields and + the original fields for each record. Defaults to False. + drop: If `append` is true, enumerates a subset of fields from the + original record that should not be kept + language: The language used to define (and execute) the + expressions and/or callables in `fields`. Defaults to generic. + error_handling: Whether and where to output records that throw errors when + the above expressions are evaluated. 
""" # pylint: disable=line-too-long input_schema, fields = normalize_fields( - pcoll, language=language, **mapping_args) + pcoll, fields, drop or (), append, language=language or 'generic') if language == 'javascript': options.YamlOptions.check_enabled(pcoll.pipeline, 'javascript') diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 33f2eeefd296..e958ea70aff8 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -16,13 +16,20 @@ # """This module defines yaml wrappings for some ML transforms.""" - from typing import Any +from typing import Callable +from typing import Dict from typing import List from typing import Optional import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference import RunInference +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.utils import python_callable from apache_beam.yaml import options +from apache_beam.yaml.yaml_utils import SafeLineLoader try: from apache_beam.ml.transforms import tft @@ -33,11 +40,436 @@ tft = None # type: ignore +class ModelHandlerProvider: + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} + + def __init__( + self, + handler, + preprocess: Optional[Dict[str, str]] = None, + postprocess: Optional[Dict[str, str]] = None): + self._handler = handler + self._preprocess_fn = self.parse_processing_transform( + preprocess, 'preprocess') or self.default_preprocess_fn() + self._postprocess_fn = self.parse_processing_transform( + postprocess, 'postprocess') or self.default_postprocess_fn() + + def inference_output_type(self): + return Any + + @staticmethod + def parse_processing_transform(processing_transform, typ): + def _parse_config(callable=None, path=None, name=None): + if callable and (path or name): + raise ValueError( + f"Cannot specify 'callable' with 'path' and 
'name' for {typ} " + f"function.") + if path and name: + return python_callable.PythonCallableWithSource.load_from_script( + FileSystems.open(path).read().decode(), name) + elif callable: + return python_callable.PythonCallableWithSource(callable) + else: + raise ValueError( + f"Must specify one of 'callable' or 'path' and 'name' for {typ} " + f"function.") + + if processing_transform: + if isinstance(processing_transform, dict): + return _parse_config(**processing_transform) + else: + raise ValueError("Invalid model_handler specification.") + + def underlying_handler(self): + return self._handler + + @staticmethod + def default_preprocess_fn(): + raise ValueError( + 'Model Handler does not implement a default preprocess ' + 'method. Please define a preprocessing method using the ' + '\'preprocess\' tag. This is required in most cases because ' + 'most models will have a different input shape, so the model ' + 'cannot generalize how the input Row should be transformed. For ' + 'an example preprocess method, see VertexAIModelHandlerJSONProvider') + + def _preprocess_fn_internal(self): + return lambda row: (row, self._preprocess_fn(row)) + + @staticmethod + def default_postprocess_fn(): + return lambda x: x + + def _postprocess_fn_internal(self): + return lambda result: (result[0], self._postprocess_fn(result[1])) + + @staticmethod + def validate(model_handler_spec): + raise NotImplementedError(type(ModelHandlerProvider)) + + @classmethod + def register_handler_type(cls, type_name): + def apply(constructor): + cls.handler_types[type_name] = constructor + return constructor + + return apply + + @classmethod + def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider": + typ = model_handler_spec['type'] + config = model_handler_spec['config'] + try: + result = cls.handler_types[typ](**config) + if not hasattr(result, 'to_json'): + result.to_json = lambda: model_handler_spec + return result + except Exception as exn: + raise ValueError( + f'Unable to 
instantiate model handler of type {typ}. {exn}') + + +@ModelHandlerProvider.register_handler_type('VertexAIModelHandlerJSON') +class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): + def __init__( + self, + endpoint_id: str, + project: str, + location: str, + preprocess: Dict[str, str], + postprocess: Optional[Dict[str, str]] = None, + experiment: Optional[str] = None, + network: Optional[str] = None, + private: bool = False, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + env_vars: Optional[Dict[str, Any]] = None): + """ + ModelHandler for Vertex AI. + + This Model Handler can be used with RunInference to load a model hosted + on VertexAI. Every model that is hosted on VertexAI should have three + distinct, required, parameters - `endpoint_id`, `project` and `location`. + These parameters tell the Model Handler how to access the model's endpoint + so that input data can be sent using an API request, and inferences can be + received as a response. + + This Model Handler also requires a `preprocess` function to be defined. + Preprocessing and Postprocessing are described in more detail in the + RunInference docs: + https://beam.apache.org/releases/yamldoc/current/#runinference + + Every model will have a unique input, but all requests should be + JSON-formatted. For example, most language models such as Llama and Gemma + expect a JSON with the key "prompt" (among other optional keys). In Python, + JSON can be expressed as a dictionary. 
+ + For example: :: + + - type: RunInference + config: + inference_tag: 'my_inference' + model_handler: + type: VertexAIModelHandlerJSON + config: + endpoint_id: 9876543210 + project: my-project + location: us-east1 + preprocess: + callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' + + In the above example, which mimics a call to a Llama 3 model hosted on + VertexAI, the preprocess function (in this case a lambda) takes in a Beam + Row with a single field, "prompt", and maps it to a dict with the same + field. It also specifies an optional parameter, "max_tokens", that tells the + model the allowed token size (in this case input + output token size). + + Args: + endpoint_id: the numerical ID of the Vertex AI endpoint to query. + project: the GCP project name where the endpoint is deployed. + location: the GCP location where the endpoint is deployed. + preprocess: A python callable, defined either inline, or using a file, + that is invoked on the input row before sending to the model to be + loaded by this ModelHandler. This parameter is required by the + `VertexAIModelHandlerJSON` ModelHandler. + postprocess: A python callable, defined either inline, or using a file, + that is invoked on the PredictionResult output by the ModelHandler + before parsing into the output Beam Row under the field name defined + by the inference_tag. + experiment: Experiment label to apply to the + queries. See + https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments + for more information. + network: The full name of the Compute Engine + network the endpoint is deployed on; used for private + endpoints. The network or subnetwork Dataflow pipeline + option must be set and match this network for pipeline + execution. + Ex: "projects/12345/global/networks/myVPC" + private: If the deployed Vertex AI endpoint is + private, set to true. Requires a network to be provided + as well. + min_batch_size: The minimum batch size to use when batching + inputs. 
+ max_batch_size: The maximum batch size to use when batching + inputs. + max_batch_duration_secs: The maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + env_vars: Environment variables. + """ + + try: + from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON + except ImportError: + raise ValueError( + 'Unable to import VertexAIModelHandlerJSON. Please ' + 'install gcp dependencies: `pip install apache_beam[gcp]`') + + _handler = VertexAIModelHandlerJSON( + endpoint_id=str(endpoint_id), + project=project, + location=location, + experiment=experiment, + network=network, + private=private, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + env_vars=env_vars or {}) + + super().__init__(_handler, preprocess, postprocess) + + @staticmethod + def validate(model_handler_spec): + pass + + def inference_output_type(self): + return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), + ('model_id', Optional[str])]) + + +def get_user_schema_fields(user_type): + return [(name, type(typ) if not isinstance(typ, type) else typ) + for (name, typ) in user_type._fields] if user_type else [] + + +@beam.ptransform.ptransform_fn +def run_inference( + pcoll, + model_handler: Dict[str, Any], + inference_tag: Optional[str] = 'inference', + inference_args: Optional[Dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long + """ + A transform that takes the input rows, containing examples (or features), for + use on an ML model. The transform then appends the inferences + (or predictions) for those examples to the input row. + + A ModelHandler must be passed to the `model_handler` parameter. The + ModelHandler is responsible for configuring how the ML model will be loaded + and how input data will be passed to it. 
Every ModelHandler has a config tag, + similar to how a transform is defined, where the parameters are defined. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + ... + + By default, the RunInference transform will return the + input row with a single field appended named by the `inference_tag` parameter + ("inference" by default) that contains the inference directly returned by the + underlying ModelHandler, after any optional postprocessing. + + For example, if the input had the following: :: + + Row(question="What is a car?") + + The output row would look like: :: + + Row(question="What is a car?", inference=...) + + where the `inference` tag can be overridden with the `inference_tag` + parameter. + + However, if one specified the following transform config: :: + + - type: RunInference + config: + inference_tag: my_inference + model_handler: ... + + The output row would look like: :: + + Row(question="What is a car?", my_inference=...) + + See more complete documentation on the underlying + [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) + transform. + + ### Preprocessing input data + + In most cases, the model will be expecting data in a particular data format, + whether it be a Python Dict, PyTorch tensor, etc. However, the outputs of all + built-in Beam YAML transforms are Beam Rows. To allow for transforming + the Beam Row into a data format the model recognizes, each ModelHandler is + equipped with a `preprocessing` parameter for performing necessary data + preprocessing. It is possible for a ModelHandler to define a default + preprocessing function, but in most cases, one will need to be specified by + the caller. + + For example, using `callable`: :: + + pipeline: + type: chain + + transforms: + - type: Create + config: + elements: + - question: "What is a car?" + - question: "Where is the Eiffel Tower located?" 
+ + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + callable: 'lambda row: {"prompt": row.question}' + ... + + In the above example, the Create transform generates a collection of two Beam + Row elements, each with a single field - "question". The model, however, + expects a Python Dict with a single key, "prompt". In this case, we can + specify a simple Lambda function (alternatively could define a full function), + to map the data. + + ### Postprocessing predictions + + It is also possible to define a postprocessing function to postprocess the + data output by the ModelHandler. See the documentation for the ModelHandler + you intend to use (list defined below under `model_handler` parameter doc). + + In many cases, before postprocessing, the object + will be a + [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). # pylint: disable=line-too-long + This type behaves very similarly to a Beam Row and fields can be accessed + using dot notation. However, make sure to check the docs for your ModelHandler + to see which fields its PredictionResult contains or if it returns a + different object altogether. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + postprocess: + callable: | + def fn(x: PredictionResult): + return beam.Row(x.example, x.inference, x.model_id) + ... + + The above example demonstrates converting the original output data type (in + this case it is PredictionResult), and converts to a Beam Row, which allows + for easier mapping in a later transform. + + ### File-based pre/postprocessing functions + + For both preprocessing and postprocessing, it is also possible to specify a + Python UDF (User-defined function) file that contains the function. 
This is + possible by specifying the `path` to the file (local file or GCS path) and + the `name` of the function in the file. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + path: gs://my-bucket/path/to/preprocess.py + name: my_preprocess_fn + postprocess: + path: gs://my-bucket/path/to/postprocess.py + name: my_postprocess_fn + ... + + Args: + model_handler: Specifies the parameters for the respective + enrichment_handler in a YAML/JSON format. To see the full set of + handler_config parameters, see their corresponding doc pages: + + - [VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider) # pylint: disable=line-too-long + inference_tag: The tag to use for the returned inference. Default is + 'inference'. + inference_args: Extra arguments for models whose inference call requires + extra parameters. Make sure to check the underlying ModelHandler docs to + see which args are allowed. + + """ + + options.YamlOptions.check_enabled(pcoll.pipeline, 'ML') + + if not isinstance(model_handler, dict): + raise ValueError( + 'Invalid model_handler specification. 
Expected dict but was ' + f'{type(model_handler)}.') + expected_model_handler_params = {'type', 'config'} + given_model_handler_params = set( + SafeLineLoader.strip_metadata(model_handler).keys()) + extra_params = given_model_handler_params - expected_model_handler_params + if extra_params: + raise ValueError(f'Unexpected parameters in model_handler: {extra_params}') + missing_params = expected_model_handler_params - given_model_handler_params + if missing_params: + raise ValueError(f'Missing parameters in model_handler: {missing_params}') + typ = model_handler['type'] + model_handler_provider_type = ModelHandlerProvider.handler_types.get( + typ, None) + if not model_handler_provider_type: + raise NotImplementedError(f'Unknown model handler type: {typ}.') + + model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + model_handler_provider.validate(model_handler['config']) + user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type) + schema = RowTypeConstraint.from_fields( + get_user_schema_fields(user_type) + + [(str(inference_tag), model_handler_provider.inference_output_type())]) + + return ( + pcoll | RunInference( + model_handler=KeyedModelHandler( + model_handler_provider.underlying_handler()).with_preprocess_fn( + model_handler_provider._preprocess_fn_internal()). 
+ with_postprocess_fn( + model_handler_provider._postprocess_fn_internal()), + inference_args=inference_args) + | beam.Map( + lambda row: beam.Row(**{ + inference_tag: row[1], **row[0]._asdict() + })).with_output_types(schema)) + + def _config_to_obj(spec): if 'type' not in spec: - raise ValueError(r"Missing type in ML transform spec {spec}") + raise ValueError(f"Missing type in ML transform spec {spec}") if 'config' not in spec: - raise ValueError(r"Missing config in ML transform spec {spec}") + raise ValueError(f"Missing config in ML transform spec {spec}") constructor = _transform_constructors.get(spec['type']) if constructor is None: raise ValueError("Unknown ML transform type: %r" % spec['type']) diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index a07638953551..b3518c568653 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -36,6 +36,7 @@ from typing import Dict from typing import Iterable from typing import Iterator +from typing import List from typing import Mapping from typing import Optional @@ -349,7 +350,7 @@ def python(urns, packages=()): @ExternalProvider.register_provider_type('pythonPackage') class ExternalPythonProvider(ExternalProvider): - def __init__(self, urns, packages): + def __init__(self, urns, packages: Iterable[str]): super().__init__(urns, PypiExpansionService(packages)) def available(self): @@ -907,7 +908,7 @@ def log_for_testing( def to_loggable_json_recursive(o): if isinstance(o, (str, bytes)): - return o + return str(o) elif callable(getattr(o, '_asdict', None)): return to_loggable_json_recursive(o._asdict()) elif isinstance(o, Mapping) and callable(getattr(o, 'items', None)): @@ -1017,26 +1018,31 @@ class PypiExpansionService: """ VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs") - def __init__(self, packages, base_python=sys.executable): - self._packages = packages + def __init__( + self, packages: 
Iterable[str], base_python: str = sys.executable): + if not isinstance(packages, Iterable) or isinstance(packages, str): + raise TypeError( + "Packages must be an iterable of strings, got %r" % packages) + self._packages = list(packages) self._base_python = base_python @classmethod - def _key(cls, base_python, packages): + def _key(cls, base_python: str, packages: List[str]) -> str: return json.dumps({ 'binary': base_python, 'packages': sorted(packages) }, sort_keys=True) @classmethod - def _path(cls, base_python, packages): + def _path(cls, base_python: str, packages: List[str]) -> str: return os.path.join( cls.VENV_CACHE, hashlib.sha256(cls._key(base_python, packages).encode('utf-8')).hexdigest()) @classmethod - def _create_venv_from_scratch(cls, base_python, packages): + def _create_venv_from_scratch( + cls, base_python: str, packages: List[str]) -> str: venv = cls._path(base_python, packages) if not os.path.exists(venv): try: @@ -1054,7 +1060,8 @@ def _create_venv_from_scratch(cls, base_python, packages): return venv @classmethod - def _create_venv_from_clone(cls, base_python, packages): + def _create_venv_from_clone( + cls, base_python: str, packages: List[str]) -> str: venv = cls._path(base_python, packages) if not os.path.exists(venv): try: @@ -1074,7 +1081,7 @@ def _create_venv_from_clone(cls, base_python, packages): return venv @classmethod - def _create_venv_to_clone(cls, base_python): + def _create_venv_to_clone(cls, base_python: str) -> str: if '.dev' in beam_version: base_venv = os.path.dirname(os.path.dirname(base_python)) print('Cloning dev environment from', base_venv) @@ -1085,7 +1092,7 @@ def _create_venv_to_clone(cls, base_python): 'virtualenv-clone' ]) - def _venv(self): + def _venv(self) -> str: return self._create_venv_from_clone(self._base_python, self._packages) def __enter__(self): diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index b8e49e81c579..12161d3d580d 100644 --- 
a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -16,13 +16,13 @@ # import collections +import datetime import functools import json import logging import os import pprint import re -import uuid from typing import Any from typing import Iterable from typing import List @@ -31,7 +31,6 @@ import jinja2 import yaml -from yaml.loader import SafeLoader import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -41,6 +40,7 @@ from apache_beam.yaml.yaml_combine import normalize_combine from apache_beam.yaml.yaml_mapping import normalize_mapping from apache_beam.yaml.yaml_mapping import validate_generic_expressions +from apache_beam.yaml.yaml_utils import SafeLineLoader __all__ = ["YamlTransform"] @@ -129,59 +129,6 @@ def empty_if_explicitly_empty(io): return io -class SafeLineLoader(SafeLoader): - """A yaml loader that attaches line information to mappings and strings.""" - class TaggedString(str): - """A string class to which we can attach metadata. - - This is primarily used to trace a string's origin back to its place in a - yaml file. - """ - def __reduce__(self): - # Pickle as an ordinary string. 
- return str, (str(self), ) - - def construct_scalar(self, node): - value = super().construct_scalar(node) - if isinstance(value, str): - value = SafeLineLoader.TaggedString(value) - value._line_ = node.start_mark.line + 1 - return value - - def construct_mapping(self, node, deep=False): - mapping = super().construct_mapping(node, deep=deep) - mapping['__line__'] = node.start_mark.line + 1 - mapping['__uuid__'] = self.create_uuid() - return mapping - - @classmethod - def create_uuid(cls): - return str(uuid.uuid4()) - - @classmethod - def strip_metadata(cls, spec, tagged_str=True): - if isinstance(spec, Mapping): - return { - cls.strip_metadata(key, tagged_str): - cls.strip_metadata(value, tagged_str) - for (key, value) in spec.items() - if key not in ('__line__', '__uuid__') - } - elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): - return [cls.strip_metadata(value, tagged_str) for value in spec] - elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: - return str(spec) - else: - return spec - - @staticmethod - def get_line(obj): - if isinstance(obj, dict): - return obj.get('__line__', 'unknown') - else: - return getattr(obj, '_line_', 'unknown') - - class LightweightScope(object): def __init__(self, transforms): self._transforms = transforms @@ -955,6 +902,21 @@ def preprocess_languages(spec): else: return spec + def validate_transform_references(spec): + name = spec.get('name', '') + transform_type = spec.get('type') + inputs = spec.get('input').get('input', []) + + if not is_empty(inputs): + input_values = [inputs] if isinstance(inputs, str) else inputs + for input_value in input_values: + if input_value in (name, transform_type): + raise ValueError( + f"Circular reference detected: Transform {name} " + f"references itself as input in {identify_object(spec)}") + + return spec + for phase in [ ensure_transforms_have_types, normalize_mapping, @@ -965,6 +927,7 @@ def preprocess_languages(spec): preprocess_chain, 
tag_explicit_inputs, normalize_inputs_outputs, + validate_transform_references, preprocess_flattened_inputs, ensure_errors_consumed, preprocess_windowing, @@ -992,7 +955,7 @@ def expand_jinja( jinja2.Environment( undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader()) .from_string(jinja_template) - .render(**jinja_variables)) + .render(datetime=datetime, **jinja_variables)) class YamlTransform(beam.PTransform): diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index 7fcea7e2b662..b9caca4ca9f4 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -259,6 +259,51 @@ def test_csv_to_json(self): lines=True).sort_values('rank').reindex() pd.testing.assert_frame_equal(data, result) + def test_circular_reference_validation(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + # pylint: disable=expression-not-assigned + with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'): + p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + name: Create + config: + elements: [0, 1, 3, 4] + input: Create + - type: PyMap + name: PyMap + config: + fn: "lambda row: row.element * row.element" + input: Create + output: PyMap + ''', + providers=TEST_PROVIDERS) + + def test_circular_reference_multi_inputs_validation(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + # pylint: disable=expression-not-assigned + with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'): + p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + name: Create + config: + elements: [0, 1, 3, 4] + - type: PyMap + name: PyMap + config: + fn: "lambda row: row.element * row.element" + input: [Create, PyMap] + output: PyMap + ''', + providers=TEST_PROVIDERS) + 
def test_name_is_not_ambiguous(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -285,7 +330,7 @@ def test_name_is_ambiguous(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: # pylint: disable=expression-not-assigned - with self.assertRaisesRegex(ValueError, r'Ambiguous.*'): + with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'): p | YamlTransform( ''' type: composite diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 084e03cdb197..5bc9de24bb38 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -23,7 +23,6 @@ from apache_beam.yaml import YamlTransform from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import InlineProvider -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope from apache_beam.yaml.yaml_transform import chain_as_composite from apache_beam.yaml.yaml_transform import ensure_errors_consumed @@ -39,57 +38,7 @@ from apache_beam.yaml.yaml_transform import preprocess_flattened_inputs from apache_beam.yaml.yaml_transform import preprocess_windowing from apache_beam.yaml.yaml_transform import push_windowing_to_roots - - -class SafeLineLoaderTest(unittest.TestCase): - def test_get_line(self): - pipeline_yaml = ''' - type: composite - input: - elements: input - transforms: - - type: PyMap - name: Square - input: elements - config: - fn: "lambda x: x * x" - - type: PyMap - name: Cube - input: elements - config: - fn: "lambda x: x * x * x" - output: - Flatten - ''' - spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) - self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) - self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) - 
self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) - self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) - self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") - - def test_strip_metadata(self): - spec_yaml = ''' - transforms: - - type: PyMap - name: Square - ''' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['transforms']) - - self.assertFalse(hasattr(stripped[0], '__line__')) - self.assertFalse(hasattr(stripped[0], '__uuid__')) - - def test_strip_metadata_nothing_to_strip(self): - spec_yaml = 'prop: 123' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['prop']) - - self.assertFalse(hasattr(stripped, '__line__')) - self.assertFalse(hasattr(stripped, '__uuid__')) +from apache_beam.yaml.yaml_utils import SafeLineLoader def new_pipeline(): diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py new file mode 100644 index 000000000000..63beb90f0711 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import uuid +from typing import Iterable +from typing import Mapping + +from yaml import SafeLoader + + +class SafeLineLoader(SafeLoader): + """A yaml loader that attaches line information to mappings and strings.""" + class TaggedString(str): + """A string class to which we can attach metadata. + + This is primarily used to trace a string's origin back to its place in a + yaml file. + """ + def __reduce__(self): + # Pickle as an ordinary string. + return str, (str(self), ) + + def construct_scalar(self, node): + value = super().construct_scalar(node) + if isinstance(value, str): + value = SafeLineLoader.TaggedString(value) + value._line_ = node.start_mark.line + 1 + return value + + def construct_mapping(self, node, deep=False): + mapping = super().construct_mapping(node, deep=deep) + mapping['__line__'] = node.start_mark.line + 1 + mapping['__uuid__'] = self.create_uuid() + return mapping + + @classmethod + def create_uuid(cls): + return str(uuid.uuid4()) + + @classmethod + def strip_metadata(cls, spec, tagged_str=True): + if isinstance(spec, Mapping): + return { + cls.strip_metadata(key, tagged_str): + cls.strip_metadata(value, tagged_str) + for (key, value) in spec.items() + if key not in ('__line__', '__uuid__') + } + elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): + return [cls.strip_metadata(value, tagged_str) for value in spec] + elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: + return str(spec) + else: + return spec + + @staticmethod + def get_line(obj): + if isinstance(obj, dict): + return 
obj.get('__line__', 'unknown') + else: + return getattr(obj, '_line_', 'unknown') diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py new file mode 100644 index 000000000000..4fd2c793e57e --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import logging +import unittest + +import yaml + +from apache_beam.yaml.yaml_utils import SafeLineLoader + + +class SafeLineLoaderTest(unittest.TestCase): + def test_get_line(self): + pipeline_yaml = ''' + type: composite + input: + elements: input + transforms: + - type: PyMap + name: Square + input: elements + config: + fn: "lambda x: x * x" + - type: PyMap + name: Cube + input: elements + config: + fn: "lambda x: x * x * x" + output: + Flatten + ''' + spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) + self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) + self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) + self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) + self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") + + def test_strip_metadata(self): + spec_yaml = ''' + transforms: + - type: PyMap + name: Square + ''' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['transforms']) + + self.assertFalse(hasattr(stripped[0], '__line__')) + self.assertFalse(hasattr(stripped[0], '__uuid__')) + + def test_strip_metadata_nothing_to_strip(self): + spec_yaml = 'prop: 123' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['prop']) + + self.assertFalse(hasattr(stripped, '__line__')) + self.assertFalse(hasattr(stripped, '__uuid__')) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 3442b92f3583..d3c4c3a76231 100644 --- 
a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -23,20 +23,20 @@ annotated-types==0.7.0 async-timeout==5.0.1 -attrs==24.2.0 +attrs==24.3.0 backports.tarfile==1.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 build==1.2.2.post1 cachetools==5.5.0 -certifi==2024.8.30 +certifi==2024.12.14 cffi==1.17.1 -charset-normalizer==3.4.0 -click==8.1.7 +charset-normalizer==3.4.1 +click==8.1.8 cloudpickle==2.2.1 -cramjam==2.9.0 +cramjam==2.9.1 crcmod==1.7 -cryptography==43.0.3 +cryptography==44.0.0 Cython==3.0.11 Deprecated==1.2.15 deprecation==2.1.0 @@ -47,32 +47,32 @@ docopt==0.6.2 docstring_parser==0.16 exceptiongroup==1.2.2 execnet==2.1.1 -fastavro==1.9.7 +fastavro==1.10.0 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.23.0 -google-api-python-client==2.153.0 +google-api-core==2.24.0 +google-api-python-client==2.156.0 google-apitools==0.5.31 -google-auth==2.36.0 +google-auth==2.37.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.72.0 +google-cloud-aiplatform==1.75.0 google-cloud-bigquery==3.27.0 google-cloud-bigquery-storage==2.27.0 google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.25.1 -google-cloud-language==2.15.1 +google-cloud-datastore==2.20.2 +google-cloud-dlp==3.26.0 +google-cloud-language==2.16.0 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.14 -google-cloud-resource-manager==1.13.1 -google-cloud-spanner==3.50.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.14.1 -google-cloud-vision==3.8.1 +google-cloud-recommendations-ai==0.10.15 +google-cloud-resource-manager==1.14.0 +google-cloud-spanner==3.51.0 +google-cloud-storage==2.19.0 +google-cloud-videointelligence==2.15.0 +google-cloud-vision==3.9.0 google-crc32c==1.6.0 google-resumable-media==2.7.2 googleapis-common-protos==1.66.0 @@ -80,11 +80,11 @@ greenlet==3.1.1 
grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 grpcio==1.65.5 -grpcio-status==1.62.3 +grpcio-status==1.65.5 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.119.1 +hypothesis==6.123.2 idna==3.10 importlib_metadata==8.5.0 iniconfig==2.0.0 @@ -92,12 +92,12 @@ jaraco.classes==3.4.0 jaraco.context==6.0.1 jaraco.functools==4.1.0 jeepney==0.8.0 -Jinja2==3.1.4 +Jinja2==3.1.5 joblib==1.4.2 jsonpickle==3.4.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 -keyring==25.5.0 +keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 mmh3==5.0.1 @@ -108,25 +108,25 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.28.1 -opentelemetry-sdk==1.28.1 -opentelemetry-semantic-conventions==0.49b1 -orjson==3.10.11 +opentelemetry-api==1.29.0 +opentelemetry-sdk==1.29.0 +opentelemetry-semantic-conventions==0.50b0 +orjson==3.10.12 overrides==7.7.0 packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 proto-plus==1.25.0 -protobuf==4.25.5 +protobuf==5.29.2 psycopg2-binary==2.9.9 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 pyasn1_modules==0.4.1 pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 +pydantic==2.10.4 +pydantic_core==2.27.2 pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 @@ -140,31 +140,31 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.2.0 +redis==5.2.1 referencing==0.35.1 regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.21.0 +rpds-py==0.22.3 rsa==4.9 -scikit-learn==1.5.2 +scikit-learn==1.6.0 scipy==1.14.1 SecretStorage==3.3.3 shapely==2.0.6 -six==1.16.0 +six==1.17.0 sortedcontainers==2.4.0 soupsieve==2.6 SQLAlchemy==2.0.36 -sqlparse==0.5.2 +sqlparse==0.5.3 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tomli==2.1.0 -tqdm==4.67.0 +tomli==2.2.1 +tqdm==4.67.1 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 -urllib3==2.2.3 -wrapt==1.16.0 +urllib3==2.3.0 +wrapt==1.17.0 zipp==3.21.0 
zstandard==0.23.0 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 93f579b14dd8..3420d0444466 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -22,20 +22,20 @@ # Reach out to a committer if you need help. annotated-types==0.7.0 -attrs==24.2.0 +attrs==24.3.0 backports.tarfile==1.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 build==1.2.2.post1 cachetools==5.5.0 -certifi==2024.8.30 +certifi==2024.12.14 cffi==1.17.1 -charset-normalizer==3.4.0 -click==8.1.7 +charset-normalizer==3.4.1 +click==8.1.8 cloudpickle==2.2.1 -cramjam==2.9.0 +cramjam==2.9.1 crcmod==1.7 -cryptography==43.0.3 +cryptography==44.0.0 Cython==3.0.11 Deprecated==1.2.15 deprecation==2.1.0 @@ -45,32 +45,32 @@ docker==7.1.0 docopt==0.6.2 docstring_parser==0.16 execnet==2.1.1 -fastavro==1.9.7 +fastavro==1.10.0 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.23.0 -google-api-python-client==2.153.0 +google-api-core==2.24.0 +google-api-python-client==2.156.0 google-apitools==0.5.31 -google-auth==2.36.0 +google-auth==2.37.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.72.0 +google-cloud-aiplatform==1.75.0 google-cloud-bigquery==3.27.0 google-cloud-bigquery-storage==2.27.0 google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.25.1 -google-cloud-language==2.15.1 +google-cloud-datastore==2.20.2 +google-cloud-dlp==3.26.0 +google-cloud-language==2.16.0 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.14 -google-cloud-resource-manager==1.13.1 -google-cloud-spanner==3.50.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.14.1 -google-cloud-vision==3.8.1 +google-cloud-recommendations-ai==0.10.15 +google-cloud-resource-manager==1.14.0 +google-cloud-spanner==3.51.0 
+google-cloud-storage==2.19.0 +google-cloud-videointelligence==2.15.0 +google-cloud-vision==3.9.0 google-crc32c==1.6.0 google-resumable-media==2.7.2 googleapis-common-protos==1.66.0 @@ -78,11 +78,11 @@ greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 grpcio==1.65.5 -grpcio-status==1.62.3 +grpcio-status==1.65.5 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.119.1 +hypothesis==6.123.2 idna==3.10 importlib_metadata==8.5.0 iniconfig==2.0.0 @@ -90,12 +90,12 @@ jaraco.classes==3.4.0 jaraco.context==6.0.1 jaraco.functools==4.1.0 jeepney==0.8.0 -Jinja2==3.1.4 +Jinja2==3.1.5 joblib==1.4.2 jsonpickle==3.4.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 -keyring==25.5.0 +keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 mmh3==5.0.1 @@ -106,25 +106,25 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.28.1 -opentelemetry-sdk==1.28.1 -opentelemetry-semantic-conventions==0.49b1 -orjson==3.10.11 +opentelemetry-api==1.29.0 +opentelemetry-sdk==1.29.0 +opentelemetry-semantic-conventions==0.50b0 +orjson==3.10.12 overrides==7.7.0 packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 proto-plus==1.25.0 -protobuf==4.25.5 +protobuf==5.29.2 psycopg2-binary==2.9.9 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 pyasn1_modules==0.4.1 pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 +pydantic==2.10.4 +pydantic_core==2.27.2 pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 @@ -138,30 +138,30 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.2.0 +redis==5.2.1 referencing==0.35.1 regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.21.0 +rpds-py==0.22.3 rsa==4.9 -scikit-learn==1.5.2 +scikit-learn==1.6.0 scipy==1.14.1 SecretStorage==3.3.3 shapely==2.0.6 -six==1.16.0 +six==1.17.0 sortedcontainers==2.4.0 soupsieve==2.6 SQLAlchemy==2.0.36 -sqlparse==0.5.2 +sqlparse==0.5.3 tenacity==8.5.0 testcontainers==3.7.1 
threadpoolctl==3.5.0 -tqdm==4.67.0 +tqdm==4.67.1 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 -urllib3==2.2.3 -wrapt==1.16.0 +urllib3==2.3.0 +wrapt==1.17.0 zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/container/py312/base_image_requirements.txt b/sdks/python/container/py312/base_image_requirements.txt index 069005318cdb..cfc9e96087d0 100644 --- a/sdks/python/container/py312/base_image_requirements.txt +++ b/sdks/python/container/py312/base_image_requirements.txt @@ -22,19 +22,19 @@ # Reach out to a committer if you need help. annotated-types==0.7.0 -attrs==24.2.0 +attrs==24.3.0 beautifulsoup4==4.12.3 bs4==0.0.2 build==1.2.2.post1 cachetools==5.5.0 -certifi==2024.8.30 +certifi==2024.12.14 cffi==1.17.1 -charset-normalizer==3.4.0 -click==8.1.7 +charset-normalizer==3.4.1 +click==8.1.8 cloudpickle==2.2.1 -cramjam==2.9.0 +cramjam==2.9.1 crcmod==1.7 -cryptography==43.0.3 +cryptography==44.0.0 Cython==3.0.11 Deprecated==1.2.15 deprecation==2.1.0 @@ -44,32 +44,32 @@ docker==7.1.0 docopt==0.6.2 docstring_parser==0.16 execnet==2.1.1 -fastavro==1.9.7 +fastavro==1.10.0 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.23.0 -google-api-python-client==2.153.0 +google-api-core==2.24.0 +google-api-python-client==2.156.0 google-apitools==0.5.31 -google-auth==2.36.0 +google-auth==2.37.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.72.0 +google-cloud-aiplatform==1.75.0 google-cloud-bigquery==3.27.0 google-cloud-bigquery-storage==2.27.0 google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.25.1 -google-cloud-language==2.15.1 +google-cloud-datastore==2.20.2 +google-cloud-dlp==3.26.0 +google-cloud-language==2.16.0 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.14 -google-cloud-resource-manager==1.13.1 -google-cloud-spanner==3.50.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.14.1 
-google-cloud-vision==3.8.1 +google-cloud-recommendations-ai==0.10.15 +google-cloud-resource-manager==1.14.0 +google-cloud-spanner==3.51.0 +google-cloud-storage==2.19.0 +google-cloud-videointelligence==2.15.0 +google-cloud-vision==3.9.0 google-crc32c==1.6.0 google-resumable-media==2.7.2 googleapis-common-protos==1.66.0 @@ -77,11 +77,11 @@ greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 grpcio==1.65.5 -grpcio-status==1.62.3 +grpcio-status==1.65.5 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.119.1 +hypothesis==6.123.2 idna==3.10 importlib_metadata==8.5.0 iniconfig==2.0.0 @@ -89,12 +89,12 @@ jaraco.classes==3.4.0 jaraco.context==6.0.1 jaraco.functools==4.1.0 jeepney==0.8.0 -Jinja2==3.1.4 +Jinja2==3.1.5 joblib==1.4.2 jsonpickle==3.4.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 -keyring==25.5.0 +keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 mmh3==5.0.1 @@ -105,25 +105,25 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.28.1 -opentelemetry-sdk==1.28.1 -opentelemetry-semantic-conventions==0.49b1 -orjson==3.10.11 +opentelemetry-api==1.29.0 +opentelemetry-sdk==1.29.0 +opentelemetry-semantic-conventions==0.50b0 +orjson==3.10.12 overrides==7.7.0 packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 proto-plus==1.25.0 -protobuf==4.25.5 +protobuf==5.29.2 psycopg2-binary==2.9.9 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 pyasn1_modules==0.4.1 pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 +pydantic==2.10.4 +pydantic_core==2.27.2 pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 @@ -137,32 +137,32 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.2.0 +redis==5.2.1 referencing==0.35.1 regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.21.0 +rpds-py==0.22.3 rsa==4.9 -scikit-learn==1.5.2 +scikit-learn==1.6.0 scipy==1.14.1 SecretStorage==3.3.3 -setuptools==75.5.0 +setuptools==75.6.0 
shapely==2.0.6 -six==1.16.0 +six==1.17.0 sortedcontainers==2.4.0 soupsieve==2.6 SQLAlchemy==2.0.36 -sqlparse==0.5.2 +sqlparse==0.5.3 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tqdm==4.67.0 +tqdm==4.67.1 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 -urllib3==2.2.3 -wheel==0.45.0 -wrapt==1.16.0 +urllib3==2.3.0 +wheel==0.45.1 +wrapt==1.17.0 zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 22ab0a2fbcf8..3fb495e30f1e 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -23,20 +23,20 @@ annotated-types==0.7.0 async-timeout==5.0.1 -attrs==24.2.0 +attrs==24.3.0 backports.tarfile==1.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 build==1.2.2.post1 cachetools==5.5.0 -certifi==2024.8.30 +certifi==2024.12.14 cffi==1.17.1 -charset-normalizer==3.4.0 -click==8.1.7 +charset-normalizer==3.4.1 +click==8.1.8 cloudpickle==2.2.1 -cramjam==2.9.0 +cramjam==2.9.1 crcmod==1.7 -cryptography==43.0.3 +cryptography==44.0.0 Cython==3.0.11 Deprecated==1.2.15 deprecation==2.1.0 @@ -47,32 +47,32 @@ docopt==0.6.2 docstring_parser==0.16 exceptiongroup==1.2.2 execnet==2.1.1 -fastavro==1.9.7 +fastavro==1.10.0 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.23.0 -google-api-python-client==2.153.0 +google-api-core==2.24.0 +google-api-python-client==2.156.0 google-apitools==0.5.31 -google-auth==2.36.0 +google-auth==2.37.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.72.0 +google-cloud-aiplatform==1.75.0 google-cloud-bigquery==3.27.0 google-cloud-bigquery-storage==2.27.0 google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.25.1 -google-cloud-language==2.15.1 +google-cloud-datastore==2.20.2 +google-cloud-dlp==3.26.0 +google-cloud-language==2.16.0 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.27.1 
google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.14 -google-cloud-resource-manager==1.13.1 -google-cloud-spanner==3.50.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.14.1 -google-cloud-vision==3.8.1 +google-cloud-recommendations-ai==0.10.15 +google-cloud-resource-manager==1.14.0 +google-cloud-spanner==3.51.0 +google-cloud-storage==2.19.0 +google-cloud-videointelligence==2.15.0 +google-cloud-vision==3.9.0 google-crc32c==1.6.0 google-resumable-media==2.7.2 googleapis-common-protos==1.66.0 @@ -80,11 +80,11 @@ greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 grpcio==1.65.5 -grpcio-status==1.62.3 +grpcio-status==1.65.5 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.119.1 +hypothesis==6.123.2 idna==3.10 importlib_metadata==8.5.0 iniconfig==2.0.0 @@ -92,12 +92,12 @@ jaraco.classes==3.4.0 jaraco.context==6.0.1 jaraco.functools==4.1.0 jeepney==0.8.0 -Jinja2==3.1.4 +Jinja2==3.1.5 joblib==1.4.2 jsonpickle==3.4.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 -keyring==25.5.0 +keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 mmh3==5.0.1 @@ -108,25 +108,25 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.28.1 -opentelemetry-sdk==1.28.1 -opentelemetry-semantic-conventions==0.49b1 -orjson==3.10.11 +opentelemetry-api==1.29.0 +opentelemetry-sdk==1.29.0 +opentelemetry-semantic-conventions==0.50b0 +orjson==3.10.12 overrides==7.7.0 packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 proto-plus==1.25.0 -protobuf==4.25.5 +protobuf==5.29.2 psycopg2-binary==2.9.9 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 pyasn1_modules==0.4.1 pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 +pydantic==2.10.4 +pydantic_core==2.27.2 pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 @@ -140,31 +140,31 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.2.0 +redis==5.2.1 referencing==0.35.1 
regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.21.0 +rpds-py==0.22.3 rsa==4.9 -scikit-learn==1.5.2 +scikit-learn==1.6.0 scipy==1.13.1 SecretStorage==3.3.3 shapely==2.0.6 -six==1.16.0 +six==1.17.0 sortedcontainers==2.4.0 soupsieve==2.6 SQLAlchemy==2.0.36 -sqlparse==0.5.2 +sqlparse==0.5.3 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tomli==2.1.0 -tqdm==4.67.0 +tomli==2.2.1 +tqdm==4.67.1 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 -urllib3==2.2.3 -wrapt==1.16.0 +urllib3==2.3.0 +wrapt==1.17.0 zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 4eb827297019..8000c24f28aa 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -26,7 +26,7 @@ requires = [ # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", # Numpy headers - "numpy>=1.14.3,<2.2.0", # Update setup.py as well. + "numpy>=1.14.3,<2.3.0", # Update setup.py as well. # having cython here will create wheels that are platform dependent. "cython>=3.0,<4", ## deps for generating external transform wrappers: diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 53c7a532e706..da9e0b2e7477 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -361,7 +361,7 @@ def get_portability_package_data(): 'jsonpickle>=3.0.0,<4.0.0', # numpy can have breaking changes in minor versions. # Use a strict upper bound. - 'numpy>=1.14.3,<2.2.0', # Update pyproject.toml as well. + 'numpy>=1.14.3,<2.3.0', # Update pyproject.toml as well. 
'objsize>=0.6.1,<0.8.0', 'packaging>=22.0', 'pymongo>=3.8.0,<5.0.0', diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index be87be749862..2d216a01f320 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -265,7 +265,8 @@ project.tasks.register("postCommitPy${pythonVersionSuffix}") { ':runners:spark:3:job-server:shadowJar', 'portableLocalRunnerJuliaSetWithSetupPy', 'portableWordCountSparkRunnerBatch', - 'portableLocalRunnerTestWithRequirementsFile'] + 'portableLocalRunnerTestWithRequirementsFile' + ] } project.tasks.register("flinkExamples") { @@ -376,7 +377,7 @@ project.tasks.register("postCommitPy${pythonVersionSuffix}IT") { ':sdks:java:testing:kafka-service:buildTestKafkaServiceJar', ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', - ':sdks:java:io:kinesis:expansion-service:shadowJar', + ':sdks:java:io:amazon-web-services2:expansion-service:shadowJar', ':sdks:java:extensions:schemaio-expansion-service:shadowJar', ':sdks:java:io:debezium:expansion-service:shadowJar' ] @@ -426,7 +427,7 @@ project.tasks.register("xlangSpannerIOIT") { ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', - ':sdks:java:io:kinesis:expansion-service:shadowJar', + ':sdks:java:io:amazon-web-services2:expansion-service:shadowJar', ':sdks:java:extensions:schemaio-expansion-service:shadowJar', ':sdks:java:io:debezium:expansion-service:shadowJar' ] diff --git a/sdks/typescript/package.json b/sdks/typescript/package.json index 9ccfcaa663d1..3ed0a0e427f4 100644 --- a/sdks/typescript/package.json +++ b/sdks/typescript/package.json @@ -1,6 +1,6 @@ { "name": "apache-beam", - "version": "2.62.0-SNAPSHOT", + "version": "2.63.0-SNAPSHOT", "devDependencies": { "@google-cloud/bigquery": "^5.12.0", 
"@types/mocha": "^9.0.0", diff --git a/settings.gradle.kts b/settings.gradle.kts index d90bb3fb5b82..6cce1ec0a506 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -204,8 +204,8 @@ include(":sdks:java:extensions:timeseries") include(":sdks:java:extensions:zetasketch") include(":sdks:java:harness") include(":sdks:java:harness:jmh") -include(":sdks:java:io:amazon-web-services") include(":sdks:java:io:amazon-web-services2") +include(":sdks:java:io:amazon-web-services2:expansion-service") include(":sdks:java:io:amqp") include(":sdks:java:io:azure") include(":sdks:java:io:azure-cosmos") @@ -238,8 +238,6 @@ include(":sdks:java:io:jms") include(":sdks:java:io:json") include(":sdks:java:io:kafka") include(":sdks:java:io:kafka:upgrade") -include(":sdks:java:io:kinesis") -include(":sdks:java:io:kinesis:expansion-service") include(":sdks:java:io:kudu") include(":sdks:java:io:mongodb") include(":sdks:java:io:mqtt") @@ -357,5 +355,3 @@ include("sdks:java:extensions:combiners") findProject(":sdks:java:extensions:combiners")?.name = "combiners" include("sdks:java:io:iceberg:hive") findProject(":sdks:java:io:iceberg:hive")?.name = "hive" -include("sdks:java:io:iceberg:hive:exec") -findProject(":sdks:java:io:iceberg:hive:exec")?.name = "exec" diff --git a/start-build-env.sh b/start-build-env.sh index b788146eb988..0f23f32a269c 100755 --- a/start-build-env.sh +++ b/start-build-env.sh @@ -91,7 +91,7 @@ RUN echo "${USER_NAME} ALL=NOPASSWD: ALL" > "/etc/sudoers.d/beam-build-${USER_ID ENV HOME "${DOCKER_HOME_DIR}" ENV GOPATH ${DOCKER_HOME_DIR}/beam/sdks/go/examples/.gogradle/project_gopath # This next command still runs as root causing the ~/.cache/go-build to be owned by root -RUN go get github.com/linkedin/goavro/v2 +RUN go mod init beam-build-${USER_ID} && go get github.com/linkedin/goavro/v2 RUN chown -R ${USER_NAME}:${GROUP_ID} ${DOCKER_HOME_DIR}/.cache UserSpecificDocker diff --git 
a/website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md b/website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md new file mode 100644 index 000000000000..60523cbb3b2a --- /dev/null +++ b/website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md @@ -0,0 +1,307 @@ +--- +type: languages +title: "Python custom multi-language pipelines guide" +--- + + +# Python custom multi-language pipelines guide + +Apache Beam's powerful model enables the development of scalable, resilient, and production-ready transforms, but the process often requires significant time and effort. + +With SDKs available in multiple languages (Java, Python, Golang, YAML, etc.), creating and maintaining transforms for each language becomes a challenge, particularly for IOs. Developers must navigate different APIs, address unique quirks, and manage ongoing maintenance—such as updates, new features, and documentation—while ensuring consistent behavior across SDKs. This results in redundant work, as the same functionality is implemented repeatedly for each language (M x N effort, where M is the number of SDKs and N is the number of transforms). + +To streamline this process, Beam’s portability framework enables the use of portable transforms that can be shared across languages. This reduces duplication, allowing developers to focus on maintaining only N transforms. Pipelines combining [portable transforms](#portable-transform) from other SDK(s) are known as [“multi-language” pipelines](../programming-guide.md#13-multi-language-pipelines-multi-language-pipelines). + +The SchemaTransform framework represents the latest advancement in enhancing this multi-language capability. + +The following jumps straight into the guide. Check out the [appendix](#appendix) section below for some of the terminology used here. 
For a runnable example, check out this [page](python-multi-language-pipelines-2.md). + +## Create a Java SchemaTransform + +For better readability, use [**TypedSchemaTransformProvider**](https://beam.apache.org/releases/javadoc/current/index.html?org/apache/beam/sdk/schemas/transforms/TypedSchemaTransformProvider.html), a [SchemaTransformProvider](#schematransformprovider) parameterized on a custom configuration type `T`. TypedSchemaTransformProvider will take care of converting the custom type definition to a Beam [Schema](../basics.md#schema), and converting an instance to a Beam Row. + +```java +TypedSchemaTransformProvider extends SchemaTransformProvider { + String identifier(); + + SchemaTransform from(T configuration); +} +``` + +### Implement a configuration + +First, set up a Beam Schema-compatible configuration. This will be used to construct the transform. AutoValue types are encouraged for readability. Adding the appropriate `@DefaultSchema` annotation will help Beam do the conversions mentioned above. + +```java +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract static class MyConfiguration { + public static Builder builder() { + return new AutoValue_MyConfiguration.Builder(); + } + @SchemaFieldDescription("Description of what foo does...") + public abstract String getFoo(); + + @SchemaFieldDescription("Description of what bar does...") + public abstract Integer getBar(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setFoo(String foo); + + public abstract Builder setBar(Integer bar); + + public abstract MyConfiguration build(); + } +} +``` + +This configuration is surfaced to foreign SDKs. 
For example, when using this transform in Python, use the following format: + +```python +with beam.Pipeline() as p: + (p + | Create([...]) + | MySchemaTransform(foo="abc", bar=123) +``` + +When using this transform in YAML, use the following format: + +```yaml +pipeline: + transforms: + - type: Create + ... + - type: MySchemaTransform + config: + foo: "abc" + bar: 123 +``` + +### Implement a TypedSchemaTransformProvider +Next, implement the `TypedSchemaTransformProvider`. The following two methods are required: + +- `identifier`: Returns a unique identifier for this transform. The [Beam standard](../programming-guide.md#1314-defining-a-urn) follows this structure: `:::`. +- `from`: Builds the transform using a provided configuration. + +An [expansion service](#expansion-service) uses these methods to find and build the transform. The `@AutoService(SchemaTransformProvider.class)` annotation is also required to ensure this provider is recognized by the expansion service. + +```java +@AutoService(SchemaTransformProvider.class) +public class MyProvider extends TypedSchemaTransformProvider { + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:my_transform:v1"; + } + + @Override + protected SchemaTransform from(MyConfiguration configuration) { + return new MySchemaTransform(configuration); + } + + private static class MySchemaTransform extends SchemaTransform { + private final MyConfiguration config; + MySchemaTransform(MyConfiguration configuration) { + this.config = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection inputRows = input.get("input"); + PCollection outputRows = inputRows.apply( + new MyJavaTransform(config.getFoo(), config.getBar())); + + return PCollectionRowTuple.of("output", outputRows); + } + } +} +``` + +#### Additional metadata (optional) +The following optional methods can help provide relevant metadata: +- `description`: Provide a human-readable 
description for the transform. Remote SDKs can use this text to generate documentation. +- `inputCollectionNames`: Provide PCollection tags that this transform expects to take in. +- `outputCollectionNames`: Provide PCollection tags this transform expects to produce. + +```java + @Override + public String description() { + return "This transform does this and that..."; + } + + @Override + public List inputCollectionNames() { + return Arrays.asList("input_1", "input_2"); + } + + @Override + public List outputCollectionNames() { + return Collections.singletonList("output"); + } +``` + +## Build an expansion service that contains the transform + +Use an expansion service to make the transform available to foreign SDKs. + +First, build a shaded JAR file that includes: +1. the transform, +2. the [**ExpansionService artifact**](https://central.sonatype.com/artifact/org.apache.beam/beam-sdks-java-expansion-service), +3. and some additional dependencies. + +### Gradle build file +```groovy +plugins { + id 'com.github.johnrengelman.shadow' version '8.1.1' + id 'application' +} + +mainClassName = "org.apache.beam.sdk.expansion.service.ExpansionService" + +dependencies { + // Dependencies for your transform + ... + + // Beam's expansion service + runtimeOnly "org.apache.beam:beam-sdks-java-expansion-service:$beamVersion" + // AutoService annotation for our SchemaTransform provider + compileOnly "com.google.auto.service:auto-service-annotations:1.0.1" + annotationProcessor "com.google.auto.service:auto-service:1.0.1" + // AutoValue annotation for our configuration object + annotationProcessor "com.google.auto.value:auto-value:1.9" +} +``` + +Next, run the shaded JAR file, and provide a port to host the service. A list of available SchemaTransformProviders will be displayed. + +```shell +$ java -jar path/to/my-expansion-service.jar 12345 + +Starting expansion service at localhost:12345 + +Registered transforms: + ... 
+Registered SchemaTransformProviders: + beam:schematransform:org.apache.beam:my_transform:v1 +``` + +The transform is discoverable at `localhost:12345`. Foreign SDKs can now discover and add it to their pipelines. The next section demonstrates how to do this with a Python pipeline. + +## Use the portable transform in a Python pipeline + +The Python SDK’s [**ExternalTransformProvider**](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.external_transform_provider.html#apache_beam.transforms.external_transform_provider.ExternalTransformProvider) +can dynamically generate wrappers for portable transforms. + +```python +from apache_beam.transforms.external_transform_provider import ExternalTransformProvider +``` + +### Connect to an expansion service +First, connect to an expansion service that contains the transform. This section demonstrates two methods of connecting to the expansion service. + +#### Connect to an already running service + +If your expansion service JAR file is already running, pass in the address: + +```python +provider = ExternalTransformProvider("localhost:12345") +``` + +#### Start a service based on a Java JAR file + +If the service lives in a JAR file but isn’t currently running, use Beam utilities to run the service in a subprocess: + +```python +from apache_beam.transforms.external import JavaJarExpansionService + +provider = ExternalTransformProvider( + JavaJarExpansionService("path/to/my-expansion-service.jar")) +``` + +You can also provide a list of services: + +```python +provider = ExternalTransformProvider([ + "localhost:12345", + JavaJarExpansionService("path/to/my-expansion-service.jar"), + JavaJarExpansionService("path/to/another-expansion-service.jar")]) +``` + +When initialized, the `ExternalTransformProvider` connects to the expansion service(s), retrieves all portable transforms, and generates a Pythonic wrapper for each one. 
+ +### Retrieve and use the transform + +Retrieve the transform using its unique identifier and use it in your multi-language pipeline: + +```python +identifier = "beam:schematransform:org.apache.beam:my_transform:v1" +MyTransform = provider.get_urn(identifier) + +with beam.Pipeline() as p: + p | beam.Create(...) | MyTransform(foo="abc", bar=123) +``` + + +### Inspect the transform's metadata +You can learn more about a portable transform’s configuration by inspecting its metadata: + +```python +import inspect + +inspect.getdoc(MyTransform) +# Output: "This transform does this and that..." + +inspect.signature(MyTransform) +# Output: (foo: "str: Description of what foo does...", +# bar: "int: Description of what bar does....") +``` + +This metadata is generated directly from the provider's implementation. The class documentation is generated from the [optional **description** method](#additional-metadata). The signature information is generated from the `@SchemaFieldDescription` annotations in the [configuration object](#implement-a-configuration). + +## Appendix + +### Portable transform + +Also known as a [cross-language transform](../glossary.md#cross-language-transforms): a transform that is made available to other SDKs (i.e. other languages) via an expansion service. Such a transform must offer a way to be constructed using language-agnostic parameter types. + +### Expansion Service + +A container that can hold multiple portable transforms. During pipeline expansion, this service will +- Look up the transform in its internal registry +- Build the transform in its native language using the provided configuration +- Expand the transform – i.e. construct the transform’s sub-graph to be inserted in the pipeline +- Establish a gRPC communication channel with the runner to exchange data and signals during pipeline execution. 
+ +### SchemaTransform + +A transform that takes and produces PCollections of Beam Rows with a predefined Schema, i.e.: + +```java +SchemaTransform extends PTransform {} +``` + +### SchemaTransformProvider + +Produces a SchemaTransform using a provided configuration. An expansion service uses this interface to identify and build the transform for foreign SDKs. + +```java +SchemaTransformProvider { + String identifier(); + + SchemaTransform from(Row configuration); + + Schema configurationSchema(); +} +``` \ No newline at end of file diff --git a/website/www/site/content/en/documentation/sdks/yaml.md b/website/www/site/content/en/documentation/sdks/yaml.md index 3559a18076ba..7459d09c3ccc 100644 --- a/website/www/site/content/en/documentation/sdks/yaml.md +++ b/website/www/site/content/en/documentation/sdks/yaml.md @@ -678,6 +678,19 @@ providers: MyCustomTransform: "pkg.subpkg.PTransformClassOrCallable" ``` +One can additionally reference an external listings of providers as follows + +``` +providers: + - include: "file:///path/to/local/providers.yaml" + - include: "gs://path/to/remote/providers.yaml" + - include: "https://example.com/hosted/providers.yaml" + ... +``` + +where `providers.yaml` is simply a yaml file containing a list of providers +in the same format as those inlined in this providers block. + ## Pipeline options [Pipeline options](https://beam.apache.org/documentation/programming-guide/#configuring-pipeline-options)