diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
index 6441524fc847..92c13ea30fe7 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
@@ -38,6 +38,7 @@
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
 import com.google.cloud.bigtable.data.v2.models.KeyOffset;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -610,6 +611,23 @@ public Read withRowFilter(RowFilter filter) {
       return withRowFilter(StaticValueProvider.of(filter));
     }
 
+    /**
+     * Returns a new {@link BigtableIO.Read} that will filter the rows read from Cloud Bigtable
+     * using the given row filter encoded with {@link RowUtils#encodeRowFilter(RowFilter)}. If
+     * {@link #withRowFilter(RowFilter)} is also set, it'll use the row filter specified in {@link
+     * #withRowFilter(RowFilter)}.
+     *
+     * <p>Does not modify this object.
+     */
+    public Read withEncodedRowFilter(ValueProvider<String> filter) {
+      checkArgumentNotNull(filter, "filter can not be null");
+      BigtableReadOptions bigtableReadOptions = getBigtableReadOptions();
+      return toBuilder()
+          .setBigtableReadOptions(
+              bigtableReadOptions.toBuilder().setEncodedRowFilter(filter).build())
+          .build();
+    }
+
     /**
      * Returns a new {@link BigtableIO.Read} that will break up read requests into smaller batches.
      * This function will switch the base BigtableIO.Reader class to using the SegmentReader. If
@@ -1939,7 +1957,21 @@ public List<ByteKeyRange> getRanges() {
 
     public @Nullable RowFilter getRowFilter() {
       ValueProvider<RowFilter> rowFilter = readOptions.getRowFilter();
-      return rowFilter != null && rowFilter.isAccessible() ? rowFilter.get() : null;
+      // An explicitly configured RowFilter takes precedence over the encoded one.
+      if (rowFilter != null && rowFilter.isAccessible()) {
+        return rowFilter.get();
+      }
+      ValueProvider<String> encoded = readOptions.getEncodedRowFilter();
+      if (encoded != null && encoded.isAccessible()) {
+        String filterString = encoded.get();
+        try {
+          return RowUtils.decodeRowFilter(filterString);
+        } catch (InvalidProtocolBufferException | IllegalArgumentException e) {
+          // Base64 decoding signals malformed input with IllegalArgumentException.
+          throw new RuntimeException("Failed to decode row filter string", e);
+        }
+      }
+      return null;
     }
 
     public @Nullable Integer getMaxBufferElementCount() {
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadOptions.java
index 46834cc9756f..d3bdd299cfe0 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadOptions.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadOptions.java
@@ -43,6 +43,9 @@ abstract class BigtableReadOptions implements Serializable {
   /** Returns the row filter to use. */
   abstract @Nullable ValueProvider<RowFilter> getRowFilter();
 
+  /** Returns the row filter string encoded with {@link RowUtils#encodeRowFilter(RowFilter)}. */
+  abstract @Nullable ValueProvider<String> getEncodedRowFilter();
+
   /** Returns the key ranges to read. */
   abstract @Nullable ValueProvider<List<ByteKeyRange>> getKeyRanges();
 
@@ -73,6 +76,8 @@ abstract static class Builder {
 
     abstract Builder setRowFilter(ValueProvider<RowFilter> rowFilter);
 
+    abstract Builder setEncodedRowFilter(ValueProvider<String> serializedRowFilter);
+
     abstract Builder setMaxBufferElementCount(@Nullable Integer maxBufferElementCount);
 
     abstract Builder setKeyRanges(ValueProvider<List<ByteKeyRange>> keyRanges);
@@ -110,6 +115,9 @@ void populateDisplayData(DisplayData.Builder builder) {
     builder
         .addIfNotNull(DisplayData.item("tableId", getTableId()).withLabel("Bigtable Table Id"))
         .addIfNotNull(DisplayData.item("rowFilter", getRowFilter()).withLabel("Row Filter"))
+        .addIfNotNull(
+            DisplayData.item("encodedRowFilter", getEncodedRowFilter())
+                .withLabel("Encoded Row Filter"))
         .addIfNotNull(DisplayData.item("keyRanges", getKeyRanges()).withLabel("Key Ranges"))
         .addIfNotNull(
             DisplayData.item("attemptTimeout", getAttemptTimeout())
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/RowUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/RowUtils.java
index a79a432cb89f..952482605576 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/RowUtils.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/RowUtils.java
@@ -17,7 +17,10 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable;
 
+import com.google.bigtable.v2.RowFilter;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
+import java.util.Base64;
 
 public class RowUtils {
   public static final String KEY = "key";
@@ -33,4 +36,15 @@ public static ByteString byteString(byte[] bytes) {
   public static ByteString byteStringUtf8(String value) {
     return ByteString.copyFromUtf8(value);
   }
+
+  /** Encodes the serialized bytes of a row filter as a Base64 string. */
+  public static String encodeRowFilter(RowFilter filter) {
+    return Base64.getEncoder().encodeToString(filter.toByteArray());
+  }
+
+  /** Decodes a row filter from a string made by {@link #encodeRowFilter(RowFilter)}. */
+  public static RowFilter decodeRowFilter(String serialized) throws InvalidProtocolBufferException {
+    byte[] decoded = Base64.getDecoder().decode(serialized);
+    return RowFilter.parseFrom(decoded);
+  }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java
index 2065772a9a4f..92c607441cd0 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java
@@ -680,6 +680,36 @@ public void testReadingWithRuntimeParameterizedFilter() throws Exception {
         defaultRead.withTableId(table).withRowFilter(StaticValueProvider.of(filter)),
         Lists.newArrayList(filteredRows));
   }
+
+  @Test
+  public void testReadingWithEncodedRowFilter() throws Exception {
+    final String table = "TEST-FILTER-TABLE";
+    final int numRows = 1001;
+    List<Row> testRows = makeTableData(table, numRows);
+    String regex = ".*17.*";
+    final KeyMatchesRegex keyPredicate = new KeyMatchesRegex(regex);
+    Iterable<Row> filteredRows =
+        testRows.stream()
+            .filter(
+                input -> {
+                  verifyNotNull(input, "input");
+                  return keyPredicate.apply(input.getKey());
+                })
+            .collect(Collectors.toList());
+
+    RowFilter filter =
+        RowFilter.newBuilder().setRowKeyRegexFilter(ByteString.copyFromUtf8(regex)).build();
+    String serializedFilter = RowUtils.encodeRowFilter(filter);
+
+    service.setupSampleRowKeys(table, 5, 10L);
+
+    runReadTest(
+        defaultRead
+            .withTableId(table)
+            .withEncodedRowFilter(StaticValueProvider.of(serializedFilter)),
+        Lists.newArrayList(filteredRows));
+  }
+
   /** Tests dynamic work rebalancing exhaustively. */
   @Test
   public void testReadingSplitAtFractionExhaustive() throws Exception {