From 1d7d1cd9b25144270c744b9b2ad39cc273841242 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Mon, 26 Jan 2026 00:32:15 +0200 Subject: [PATCH] Set quota project in beam.io.ReadFromBigQuery in Python SDK --- CHANGES.md | 2 +- sdks/python/apache_beam/internal/gcp/auth.py | 41 ++++ .../apache_beam/internal/gcp/auth_test.py | 60 +++++ sdks/python/apache_beam/io/gcp/bigquery.py | 86 +++++++- .../apache_beam/io/gcp/bigquery_test.py | 207 ++++++++++++++++++ .../apache_beam/io/gcp/bigquery_tools.py | 81 ++++++- .../apache_beam/io/gcp/bigquery_tools_test.py | 75 ++++++- .../apache_beam/options/pipeline_options.py | 8 + 8 files changed, 528 insertions(+), 32 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index ff931802addf..b63d04a7f73c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -68,7 +68,7 @@ ## New Features / Improvements -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Added support for setting quota project ID in BigQuery read operations via `--quota_project_id` pipeline option or `quota_project_id` parameter in ReadFromBigQuery transform (Python) ([#37431](https://github.com/apache/beam/issues/37431)). ## Breaking Changes diff --git a/sdks/python/apache_beam/internal/gcp/auth.py b/sdks/python/apache_beam/internal/gcp/auth.py index 168d6aa26939..9a848ea430be 100644 --- a/sdks/python/apache_beam/internal/gcp/auth.py +++ b/sdks/python/apache_beam/internal/gcp/auth.py @@ -82,6 +82,47 @@ def get_service_credentials(pipeline_options): return _Credentials.get_service_credentials(pipeline_options) +def with_quota_project(credentials, quota_project_id): + """For internal use only; no backwards-compatibility guarantees. + + Apply a quota project to credentials if supported. + + The quota project is used to bill API requests to a specific GCP project, + separate from the project that owns the service account or data. + + Args: + credentials: The credentials object (either _ApitoolsCredentialsAdapter + or a google.auth credentials object). + quota_project_id: The GCP project ID to use for quota and billing. + + Returns: + Credentials with the quota project applied, or the original credentials + if quota project is not supported or credentials is None. + """ + if credentials is None or quota_project_id is None: + return credentials + + # Get the underlying google-auth credentials if wrapped + if hasattr(credentials, 'get_google_auth_credentials'): + underlying_creds = credentials.get_google_auth_credentials() + else: + underlying_creds = credentials + + # Apply quota project if supported + if hasattr(underlying_creds, 'with_quota_project'): + new_creds = underlying_creds.with_quota_project(quota_project_id) + # Re-wrap if the original was wrapped + if hasattr(credentials, 'get_google_auth_credentials'): + return _ApitoolsCredentialsAdapter(new_creds) + return new_creds + + _LOGGER.warning( + 'Credentials of type %s do not support quota project. ' + 'The quota_project_id parameter will be ignored.', + type(underlying_creds).__name__) + return credentials + + if _GOOGLE_AUTH_AVAILABLE: class _ApitoolsCredentialsAdapter: diff --git a/sdks/python/apache_beam/internal/gcp/auth_test.py b/sdks/python/apache_beam/internal/gcp/auth_test.py index fe16acc3c089..1f811a16287e 100644 --- a/sdks/python/apache_beam/internal/gcp/auth_test.py +++ b/sdks/python/apache_beam/internal/gcp/auth_test.py @@ -132,5 +132,65 @@ def raise_(scopes=None): auth._LOGGER.removeHandler(loggerHandler) +@unittest.skipIf(gauth is None, 'Google Auth dependencies are not installed') +class WithQuotaProjectTest(unittest.TestCase): + """Tests for with_quota_project function.""" + def test_with_quota_project_returns_credentials_unchanged_when_none(self): + """Test that None credentials are returned unchanged.""" + result = auth.with_quota_project(None, 'my-project') + self.assertIsNone(result) + + def test_with_quota_project_returns_credentials_unchanged_when_no_quota(self): + """Test that credentials are returned unchanged when + quota_project_id is None.""" + mock_creds = mock.MagicMock() + result = auth.with_quota_project(mock_creds, None) + self.assertEqual(result, mock_creds) + mock_creds.with_quota_project.assert_not_called() + + @mock.patch('apache_beam.internal.gcp.auth._ApitoolsCredentialsAdapter') + def test_with_quota_project_applies_quota_to_wrapped_credentials( + self, mock_adapter_class): + """Test that quota project is applied to wrapped credentials.""" + mock_inner_creds = mock.MagicMock() + mock_new_creds = mock.MagicMock() + mock_inner_creds.with_quota_project.return_value = mock_new_creds + + mock_adapter = mock.MagicMock() + mock_adapter.get_google_auth_credentials.return_value = mock_inner_creds + + mock_adapter_instance = mock.MagicMock() + mock_adapter_class.return_value = mock_adapter_instance + + result = auth.with_quota_project(mock_adapter, 'my-billing-project') + + mock_inner_creds.with_quota_project.assert_called_once_with( + 'my-billing-project') + # Result should be a new adapter wrapping the new credentials + mock_adapter_class.assert_called_once_with(mock_new_creds) + self.assertEqual(result, mock_adapter_instance) + + def test_with_quota_project_applies_quota_to_direct_credentials(self): + """Test that quota project is applied to direct credentials.""" + mock_creds = mock.MagicMock(spec=['with_quota_project']) + mock_new_creds = mock.MagicMock() + mock_creds.with_quota_project.return_value = mock_new_creds + + result = auth.with_quota_project(mock_creds, 'my-billing-project') + + mock_creds.with_quota_project.assert_called_once_with('my-billing-project') + self.assertEqual(result, mock_new_creds) + + def test_with_quota_project_returns_original_when_not_supported(self): + """Test that original credentials are returned when + with_quota_project is not supported.""" + # Create a mock without with_quota_project method + mock_creds = mock.MagicMock(spec=[]) + + result = auth.with_quota_project(mock_creds, 'my-billing-project') + + self.assertEqual(result, mock_creds) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 181c891c1b65..8d4a9ec6b948 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -378,6 +378,7 @@ def chain_after(result): import apache_beam as beam from apache_beam import coders from apache_beam import pvalue +from apache_beam.internal.gcp import auth from apache_beam.internal.gcp.json_value import from_json_value from apache_beam.internal.gcp.json_value import to_json_value from apache_beam.io import range_trackers @@ -662,7 +663,8 @@ def __init__( step_name=None, unique_id=None, temp_dataset=None, - query_priority=BigQueryQueryPriority.BATCH): + query_priority=BigQueryQueryPriority.BATCH, + quota_project_id=None): if table is not None and query is not None: raise ValueError( 'Both a BigQuery table and a query were specified.' @@ -696,6 +698,7 @@ def __init__( self.use_json_exports = use_json_exports self.temp_dataset = temp_dataset self.query_priority = query_priority + self.quota_project_id = quota_project_id self._job_name = job_name or 'BQ_EXPORT_JOB' self._step_name = step_name self._source_uuid = unique_id @@ -715,6 +718,7 @@ def display_data(self): 'use_legacy_sql': self.use_legacy_sql, 'bigquery_job_labels': json.dumps(self.bigquery_job_labels), 'export_file_format': export_format, + 'quota_project_id': self._get_quota_project_id() or '', 'launchesBigQueryJobs': DisplayDataItem( True, label="This Dataflow job launches bigquery jobs."), } @@ -782,6 +786,18 @@ def _get_project(self): project = self.project return project + def _get_quota_project_id(self): + """Returns the quota project ID for API calls. + + Prefers the explicit quota_project_id parameter, falls back to + quota_project_id from GoogleCloudOptions. + """ + if self.quota_project_id: + return self.quota_project_id + if self.options is not None: + return self.options.view_as(GoogleCloudOptions).quota_project_id + return None + def _create_source(self, path, coder): if not self.use_json_exports: return create_avro_source(path, validate=self.validate) @@ -799,7 +815,8 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): bq = bigquery_tools.BigQueryWrapper( temp_dataset_id=( self.temp_dataset.datasetId if self.temp_dataset else None), - client=bigquery_tools.BigQueryWrapper._bigquery_client(self.options)) + client=bigquery_tools.BigQueryWrapper._bigquery_client(self.options), + quota_project_id=self._get_quota_project_id()) if self.query is not None: self._setup_temporary_dataset(bq) @@ -932,6 +949,31 @@ def _export_files(self, bq): return table.schema, metadata_list +def _create_bq_storage_client(quota_project_id=None): + """Create a BigQueryReadClient with optional quota project. + + Args: + quota_project_id: Optional GCP project ID to use for quota and billing. + + Returns: + A BigQueryReadClient instance. + """ + if quota_project_id: + try: + import google.auth + from google.auth import exceptions as auth_exceptions + credentials, _ = google.auth.default() + credentials = auth.with_quota_project(credentials, quota_project_id) + return bq_storage.BigQueryReadClient(credentials=credentials) + except (auth_exceptions.DefaultCredentialsError, AttributeError) as e: + _LOGGER.warning( + 'Failed to apply quota project %s to BigQuery Storage client: %s. ' + 'Falling back to default client.', + quota_project_id, + e) + return bq_storage.BigQueryReadClient() + + class _CustomBigQueryStorageSource(BoundedSource): """A base class for BoundedSource implementations which read from BigQuery using the BigQuery Storage API. @@ -989,7 +1031,8 @@ def __init__( temp_dataset: Optional[DatasetReference] = None, temp_table: Optional[TableReference] = None, use_native_datetime: Optional[bool] = False, - timeout: Optional[float] = None): + timeout: Optional[float] = None, + quota_project_id: Optional[str] = None): if table is not None and query is not None: raise ValueError( @@ -1028,6 +1071,7 @@ def __init__( self._job_name = job_name or 'BQ_DIRECT_READ_JOB' self._step_name = step_name self._source_uuid = unique_id + self.quota_project_id = quota_project_id def _get_project(self): """Returns the project that queries and exports will be billed to.""" @@ -1039,6 +1083,18 @@ def _get_project(self): return project return self.project + def _get_quota_project_id(self): + """Returns the quota project ID for API calls. + + Prefers the explicit quota_project_id parameter, falls back to + quota_project_id from GoogleCloudOptions. + """ + if self.quota_project_id: + return self.quota_project_id + if self.pipeline_options is not None: + return self.pipeline_options.view_as(GoogleCloudOptions).quota_project_id + return None + def _get_parent_project(self): """Returns the project that will be billed.""" if self.temp_table: @@ -1168,7 +1224,8 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): bq = bigquery_tools.BigQueryWrapper( temp_table_ref=(self.temp_table if self.temp_table else None), client=bigquery_tools.BigQueryWrapper._bigquery_client( - self.pipeline_options)) + self.pipeline_options), + quota_project_id=self._get_quota_project_id()) if self.query is not None: self._setup_temporary_dataset(bq) @@ -1201,7 +1258,7 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): if self.row_restriction is not None: requested_session.read_options.row_restriction = self.row_restriction - storage_client = bq_storage.BigQueryReadClient() + storage_client = _create_bq_storage_client(self._get_quota_project_id()) stream_count = 0 if desired_bundle_size > 0: table_size = self._get_table_size(bq, self.table_reference) @@ -1232,8 +1289,10 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): self.split_result = [ _CustomBigQueryStorageStreamSource( - stream.name, self.use_native_datetime, self.timeout) - for stream in read_session.streams + stream.name, + self.use_native_datetime, + self.timeout, + self._get_quota_project_id()) for stream in read_session.streams ] for source in self.split_result: @@ -1267,10 +1326,12 @@ def __init__( self, read_stream_name: str, use_native_datetime: Optional[bool] = True, - timeout: Optional[float] = None): + timeout: Optional[float] = None, + quota_project_id: Optional[str] = None): self.read_stream_name = read_stream_name self.use_native_datetime = use_native_datetime self.timeout = timeout + self.quota_project_id = quota_project_id def display_data(self): return { @@ -1293,7 +1354,10 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): return SourceBundle( weight=1.0, source=_CustomBigQueryStorageStreamSource( - self.read_stream_name, self.use_native_datetime), + self.read_stream_name, + self.use_native_datetime, + self.timeout, + self.quota_project_id), start_position=None, stop_position=None) @@ -1329,7 +1393,7 @@ def retry_delay_callback(delay): def read_arrow(self): - storage_client = bq_storage.BigQueryReadClient() + storage_client = _create_bq_storage_client(self.quota_project_id) read_rows_kwargs = {'retry_delay_callback': self.retry_delay_callback} if self.timeout is not None: read_rows_kwargs['timeout'] = self.timeout @@ -1348,7 +1412,7 @@ def read_arrow(self): yield py_row def read_avro(self): - storage_client = bq_storage.BigQueryReadClient() + storage_client = _create_bq_storage_client(self.quota_project_id) read_rows_kwargs = {'retry_delay_callback': self.retry_delay_callback} if self.timeout is not None: read_rows_kwargs['timeout'] = self.timeout diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py index 234c99847a44..c09cecd7bff9 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py @@ -65,6 +65,7 @@ from apache_beam.io.gcp.tests.bigquery_matcher import BigQueryTableMatcher from apache_beam.metrics.metric import Lineage from apache_beam.options import value_provider +from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.value_provider import RuntimeValueProvider @@ -338,6 +339,10 @@ def test_repeatable_field_is_properly_converted(self): class TestReadFromBigQuery(unittest.TestCase): @classmethod def setUpClass(cls): + cls.env_patch = mock.patch.dict( + os.environ, {'GOOGLE_CLOUD_PROJECT': 'test-project'}) + cls.env_patch.start() + class UserDefinedOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): @@ -351,6 +356,7 @@ def tearDown(self): @classmethod def tearDownClass(cls): + cls.env_patch.stop() # Unset the option added in setupClass to avoid interfere with other tests. # Force a gc so PipelineOptions.__subclass__() no longer contains it. del cls.UserDefinedOptions @@ -778,6 +784,196 @@ def test_read_all_lineage(self): ])) +@unittest.skipIf( + HttpError is None or gcp_bigquery is None, + 'GCP dependencies are not installed') +class TestReadFromBigQueryQuotaProject(unittest.TestCase): + """Tests for quota_project_id in ReadFromBigQuery sources.""" + def test_quota_project_id_from_pipeline_options(self): + """Test that quota_project_id is read from GoogleCloudOptions.""" + options = PipelineOptions(['--quota_project_id=my-billing-project']) + gcp_options = options.view_as(GoogleCloudOptions) + self.assertEqual(gcp_options.quota_project_id, 'my-billing-project') + + def test_quota_project_id_none_by_default_in_options(self): + """Test that quota_project_id is None by default in options.""" + options = PipelineOptions([]) + gcp_options = options.view_as(GoogleCloudOptions) + self.assertIsNone(gcp_options.quota_project_id) + + def test_export_source_explicit_quota_project(self): + """Test that explicit quota_project_id is stored in + _CustomBigQuerySource.""" + source = beam_bq._CustomBigQuerySource( + method=ReadFromBigQuery.Method.EXPORT, + table='project:dataset.table', + quota_project_id='my-billing-project') + self.assertEqual(source.quota_project_id, 'my-billing-project') + self.assertEqual(source._get_quota_project_id(), 'my-billing-project') + + def test_export_source_gets_quota_from_options(self): + """Test that _CustomBigQuerySource falls back to options for + quota_project_id.""" + options = PipelineOptions(['--quota_project_id=my-billing-project']) + source = beam_bq._CustomBigQuerySource( + method=ReadFromBigQuery.Method.EXPORT, + table='project:dataset.table', + pipeline_options=options) + self.assertEqual(source._get_quota_project_id(), 'my-billing-project') + + def test_export_source_explicit_overrides_options(self): + """Test that explicit quota_project_id overrides options.""" + options = PipelineOptions(['--quota_project_id=options-project']) + source = beam_bq._CustomBigQuerySource( + method=ReadFromBigQuery.Method.EXPORT, + table='project:dataset.table', + pipeline_options=options, + quota_project_id='explicit-project') + self.assertEqual(source._get_quota_project_id(), 'explicit-project') + + def test_storage_source_explicit_quota_project(self): + """Test that explicit quota_project_id is stored in + _CustomBigQueryStorageSource.""" + source = beam_bq._CustomBigQueryStorageSource( + method=ReadFromBigQuery.Method.DIRECT_READ, + table='project:dataset.table', + quota_project_id='my-billing-project') + self.assertEqual(source.quota_project_id, 'my-billing-project') + self.assertEqual(source._get_quota_project_id(), 'my-billing-project') + + def test_storage_source_gets_quota_from_options(self): + """Test that _CustomBigQueryStorageSource falls back to options.""" + options = PipelineOptions(['--quota_project_id=my-billing-project']) + source = beam_bq._CustomBigQueryStorageSource( + method=ReadFromBigQuery.Method.DIRECT_READ, + table='project:dataset.table', + pipeline_options=options) + self.assertEqual(source._get_quota_project_id(), 'my-billing-project') + + def test_quota_project_id_in_export_source_display_data(self): + """Test that quota_project_id appears in display data for export source.""" + source = beam_bq._CustomBigQuerySource( + method=ReadFromBigQuery.Method.EXPORT, + table='project:dataset.table', + quota_project_id='my-billing-project') + display_data = source.display_data() + self.assertEqual(display_data['quota_project_id'], 'my-billing-project') + + def test_quota_project_id_empty_in_display_data_when_not_set(self): + """Test that quota_project_id is empty string in display data + when not set.""" + options = PipelineOptions([]) + source = beam_bq._CustomBigQuerySource( + method=ReadFromBigQuery.Method.EXPORT, + table='project:dataset.table', + pipeline_options=options) + display_data = source.display_data() + self.assertEqual(display_data['quota_project_id'], '') + + def test_stream_source_stores_quota_project_id(self): + """Test that quota_project_id is stored in + _CustomBigQueryStorageStreamSource.""" + stream_source = beam_bq._CustomBigQueryStorageStreamSource( + read_stream_name='projects/p/locations/l/sessions/s/streams/stream1', + use_native_datetime=True, + timeout=30.0, + quota_project_id='my-billing-project') + self.assertEqual(stream_source.quota_project_id, 'my-billing-project') + + def test_stream_source_quota_project_id_none_by_default(self): + """Test that quota_project_id is None by default in stream source.""" + stream_source = beam_bq._CustomBigQueryStorageStreamSource( + read_stream_name='projects/p/locations/l/sessions/s/streams/stream1') + self.assertIsNone(stream_source.quota_project_id) + + def test_stream_source_split_preserves_quota_project_id(self): + """Test that split() preserves quota_project_id.""" + stream_source = beam_bq._CustomBigQueryStorageStreamSource( + read_stream_name='projects/p/locations/l/sessions/s/streams/stream1', + use_native_datetime=True, + timeout=30.0, + quota_project_id='my-billing-project') + bundle = stream_source.split(desired_bundle_size=0) + self.assertEqual(bundle.source.quota_project_id, 'my-billing-project') + self.assertEqual(bundle.source.timeout, 30.0) + self.assertEqual(bundle.source.use_native_datetime, True) + + @mock.patch('apache_beam.io.gcp.bigquery._create_bq_storage_client') + def test_stream_source_read_arrow_uses_quota_project( + self, mock_create_client): + """Test that read_arrow() uses _create_bq_storage_client + with quota_project_id.""" + mock_client = mock.MagicMock() + mock_create_client.return_value = mock_client + # Mock read_rows to return empty iterator + mock_client.read_rows.return_value.rows.return_value = iter([]) + + stream_source = beam_bq._CustomBigQueryStorageStreamSource( + read_stream_name='projects/p/locations/l/sessions/s/streams/stream1', + use_native_datetime=True, + quota_project_id='my-billing-project') + # Consume the iterator + list(stream_source.read_arrow()) + + mock_create_client.assert_called_once_with('my-billing-project') + + @mock.patch('apache_beam.io.gcp.bigquery._create_bq_storage_client') + def test_stream_source_read_avro_uses_quota_project(self, mock_create_client): + """Test that read_avro() uses _create_bq_storage_client + with quota_project_id.""" + mock_client = mock.MagicMock() + mock_create_client.return_value = mock_client + # Mock read_rows to return empty iterator + mock_client.read_rows.return_value = iter([]) + + stream_source = beam_bq._CustomBigQueryStorageStreamSource( + read_stream_name='projects/p/locations/l/sessions/s/streams/stream1', + use_native_datetime=False, + quota_project_id='my-billing-project') + # Consume the iterator + list(stream_source.read_avro()) + + mock_create_client.assert_called_once_with('my-billing-project') + + @mock.patch('apache_beam.io.gcp.bigquery.bq_storage') + @mock.patch('apache_beam.io.gcp.bigquery._LOGGER') + def test_create_bq_storage_client_logs_on_failure( + self, mock_logger, mock_bq_storage): + """Test that _create_bq_storage_client logs when quota project fails.""" + # Make google.auth.default raise a DefaultCredentialsError + from google.auth import exceptions as auth_exceptions + with mock.patch( + 'google.auth.default', + side_effect=auth_exceptions.DefaultCredentialsError('Auth error')): + beam_bq._create_bq_storage_client('my-billing-project') + + mock_logger.warning.assert_called_once() + warning_args = mock_logger.warning.call_args[0] + self.assertIn('Failed to apply quota project', warning_args[0]) + self.assertIn('my-billing-project', warning_args[1]) + + @mock.patch('apache_beam.io.gcp.bigquery.bq_storage') + def test_create_bq_storage_client_with_quota_project(self, mock_bq_storage): + """Test _create_bq_storage_client applies quota project to credentials.""" + mock_creds = mock.MagicMock(spec=['with_quota_project']) + mock_new_creds = mock.MagicMock() + mock_creds.with_quota_project.return_value = mock_new_creds + + with mock.patch('google.auth.default', return_value=(mock_creds, 'proj')): + beam_bq._create_bq_storage_client('my-billing-project') + + mock_creds.with_quota_project.assert_called_once_with('my-billing-project') + mock_bq_storage.BigQueryReadClient.assert_called_once_with( + credentials=mock_new_creds) + + @mock.patch('apache_beam.io.gcp.bigquery.bq_storage') + def test_create_bq_storage_client_without_quota_project( + self, mock_bq_storage): + """Test _create_bq_storage_client without quota project uses default.""" + beam_bq._create_bq_storage_client(None) + mock_bq_storage.BigQueryReadClient.assert_called_once_with() + + @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQuerySink(unittest.TestCase): def test_table_spec_display_data(self): @@ -819,10 +1015,14 @@ def _cleanup_files(self): os.remove('insert_calls2') def setUp(self): + self.env_patch = mock.patch.dict( + os.environ, {'GOOGLE_CLOUD_PROJECT': 'test-project'}) + self.env_patch.start() self._cleanup_files() def tearDown(self): self._cleanup_files() + self.env_patch.stop() def test_noop_schema_parsing(self): expected_table_schema = None @@ -1238,6 +1438,13 @@ def test_copy_load_job_exception(self, exception_type, error_message): HttpError is None or exceptions is None, 'GCP dependencies are not installed') class BigQueryStreamingInsertsErrorHandling(unittest.TestCase): + def setUp(self): + self.env_patch = mock.patch.dict( + os.environ, {'GOOGLE_CLOUD_PROJECT': 'test-project'}) + self.env_patch.start() + + def tearDown(self): + self.env_patch.stop() # Running tests with a variety of exceptions from https://googleapis.dev # /python/google-api-core/latest/_modules/google/api_core/exceptions.html. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index ddab941f9278..abbfb8379563 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -60,6 +60,7 @@ from apache_beam.metrics import monitoring_infos from apache_beam.metrics.metric import Metrics from apache_beam.options import value_provider +from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.transforms import DoFn from apache_beam.typehints.row_type import RowTypeConstraint @@ -359,11 +360,25 @@ class BigQueryWrapper(object): HISTOGRAM_METRIC_LOGGER = MetricLogger() - def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None): - self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions()) - self.gcp_bq_client = client or gcp_bigquery.Client( - client_info=ClientInfo( - user_agent="apache-beam-%s" % apache_beam.__version__)) + def __init__( + self, + client=None, + temp_dataset_id=None, + temp_table_ref=None, + quota_project_id=None): + self.quota_project_id = quota_project_id + self.client = client or BigQueryWrapper._bigquery_client( + PipelineOptions(), quota_project_id=quota_project_id) + + # If the client is a mock (common in tests) or has the specific method + # we use, we use it as the gcp_bq_client to preserve backward + # compatibility for tests. Otherwise (e.g. it's a real apitools client), + # we create the correct google-cloud-bigquery client. + if client and hasattr(client, 'insert_rows_json'): + self.gcp_bq_client = client + else: + self.gcp_bq_client = BigQueryWrapper._gcp_bigquery_client( + quota_project_id=quota_project_id) self._unique_row_id = 0 # For testing scenarios where we pass in a client we do not want a @@ -1399,19 +1414,69 @@ def convert_row_to_dict(self, row, schema): @staticmethod def from_pipeline_options(pipeline_options: PipelineOptions): + """Create a BigQueryWrapper from pipeline options. + + Args: + pipeline_options: Pipeline options containing GCP configuration. + The quota_project_id is read from GoogleCloudOptions if set. + """ + quota_project_id = None + if pipeline_options is not None: + quota_project_id = pipeline_options.view_as( + GoogleCloudOptions).quota_project_id return BigQueryWrapper( - client=BigQueryWrapper._bigquery_client(pipeline_options)) + client=BigQueryWrapper._bigquery_client(pipeline_options), + quota_project_id=quota_project_id) @staticmethod - def _bigquery_client(pipeline_options: PipelineOptions): + def _bigquery_client( + pipeline_options: PipelineOptions, quota_project_id: str = None): + """Create a BigQuery API client from pipeline options. + + Args: + pipeline_options: Pipeline options for credentials. + quota_project_id: Optional quota project ID. If not provided, will be + extracted from pipeline_options. + """ + credentials = auth.get_service_credentials(pipeline_options) + # Use explicit quota_project_id if provided, otherwise get from options + if quota_project_id is None and pipeline_options is not None: + quota_project_id = pipeline_options.view_as( + GoogleCloudOptions).quota_project_id + if quota_project_id: + credentials = auth.with_quota_project(credentials, quota_project_id) return bigquery.BigqueryV2( http=get_new_http(), - credentials=auth.get_service_credentials(pipeline_options), + credentials=credentials, response_encoding='utf8', additional_http_headers={ "user-agent": "apache-beam-%s" % apache_beam.__version__ }) + @staticmethod + def _gcp_bigquery_client(quota_project_id: str = None): + """Create a google-cloud-bigquery Client with optional quota project.""" + credentials = None + + if quota_project_id: + # Get default credentials and apply quota project + try: + import google.auth + from google.auth import exceptions as auth_exceptions + credentials, _ = google.auth.default() + credentials = auth.with_quota_project(credentials, quota_project_id) + except (auth_exceptions.DefaultCredentialsError, AttributeError) as e: + _LOGGER.warning( + 'Failed to apply quota project %s to gcp-bigquery client: %s. ' + 'Falling back to default client.', + quota_project_id, + e) + + return gcp_bigquery.Client( + credentials=credentials, + client_info=ClientInfo( + user_agent="apache-beam-%s" % apache_beam.__version__)) + class RowAsDictJsonCoder(coders.Coder): """A coder for a table row (represented as a dict) to/from a JSON string. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index 2594e6728e0e..ae835589205c 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -23,6 +23,7 @@ import json import logging import math +import os import re import unittest from typing import Optional @@ -45,6 +46,7 @@ from apache_beam.io.gcp.bigquery_tools import check_schema_equal from apache_beam.io.gcp.bigquery_tools import generate_bq_job_name from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper from apache_beam.io.gcp.bigquery_tools import parse_table_reference from apache_beam.io.gcp.bigquery_tools import parse_table_schema_from_json from apache_beam.io.gcp.internal.clients import bigquery @@ -236,14 +238,16 @@ def test_delete_dataset_retries_for_timeouts(self, patched_time_sleep): @mock.patch('google.cloud._http.JSONConnection.http') def test_user_agent_insert_all( self, http_mock, patched_skip_get_credentials, patched_sleep): - wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper() - try: - wrapper._insert_all_rows('p', 'd', 't', [{'name': 'any'}], None) - except: # pylint: disable=bare-except - # Ignore errors. The errors come from the fact that we did not mock - # the response from the API, so the overall insert_all_rows call fails - # soon after the BQ API is called. - pass + # Set GOOGLE_CLOUD_PROJECT to ensure Client creation succeeds in test env + with mock.patch.dict(os.environ, {'GOOGLE_CLOUD_PROJECT': 'test-project'}): + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper() + try: + wrapper._insert_all_rows('p', 'd', 't', [{'name': 'any'}], None) + except: # pylint: disable=bare-except + # Ignore errors. The errors come from the fact that we did not mock + # the response from the API, so the overall insert_all_rows call fails + # soon after the BQ API is called. + pass call = http_mock.request.mock_calls[-2] self.assertIn('apache-beam-', call[2]['headers']['User-Agent']) @@ -1106,8 +1110,6 @@ def test_geography_in_bigquery_type_mapping(self): def test_geography_field_conversion(self): """Test that GEOGRAPHY fields are converted correctly.""" - from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper - # Create a mock field with GEOGRAPHY type field = bigquery.TableFieldSchema() field.type = 'GEOGRAPHY' @@ -1229,8 +1231,6 @@ def test_geography_json_encoding(self): def test_geography_with_special_characters(self): """Test GEOGRAPHY values with special characters and geometries.""" - from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper - field = bigquery.TableFieldSchema() field.type = 'GEOGRAPHY' field.name = 'complex_geo' @@ -1248,6 +1248,57 @@ def test_geography_with_special_characters(self): self.assertIsInstance(result, str) +@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') +class TestBigQueryWrapperQuotaProject(unittest.TestCase): + """Tests for quota_project_id in BigQueryWrapper.""" + @mock.patch( + 'apache_beam.io.gcp.bigquery_tools.BigQueryWrapper._bigquery_client') + @mock.patch( + 'apache_beam.io.gcp.bigquery_tools.BigQueryWrapper._gcp_bigquery_client') + def test_quota_project_id_stored(self, mock_gcp_client, mock_bq_client): + """Test that quota_project_id is stored in BigQueryWrapper.""" + mock_bq_client.return_value = mock.Mock() + mock_gcp_client.return_value = mock.Mock() + + wrapper = BigQueryWrapper(quota_project_id='my-billing-project') + self.assertEqual(wrapper.quota_project_id, 'my-billing-project') + + @mock.patch( + 'apache_beam.io.gcp.bigquery_tools.BigQueryWrapper._bigquery_client') + @mock.patch( + 'apache_beam.io.gcp.bigquery_tools.BigQueryWrapper._gcp_bigquery_client') + def test_from_pipeline_options_reads_quota_from_options( + self, mock_gcp_client, mock_bq_client): + """Test from_pipeline_options reads quota_project_id from + GoogleCloudOptions.""" + from apache_beam.options.pipeline_options import PipelineOptions + + mock_bq_client.return_value = mock.Mock() + mock_gcp_client.return_value = mock.Mock() + + options = PipelineOptions(['--quota_project_id=my-billing-project']) + wrapper = BigQueryWrapper.from_pipeline_options(options) + + self.assertEqual(wrapper.quota_project_id, 'my-billing-project') + + @mock.patch( + 'apache_beam.io.gcp.bigquery_tools.BigQueryWrapper._bigquery_client') + @mock.patch( + 'apache_beam.io.gcp.bigquery_tools.BigQueryWrapper._gcp_bigquery_client') + def test_from_pipeline_options_none_when_not_set( + self, mock_gcp_client, mock_bq_client): + """Test from_pipeline_options returns None when quota_project_id not set.""" + from apache_beam.options.pipeline_options import PipelineOptions + + mock_bq_client.return_value = mock.Mock() + mock_gcp_client.return_value = mock.Mock() + + options = PipelineOptions([]) + wrapper = BigQueryWrapper.from_pipeline_options(options) + + self.assertIsNone(wrapper.quota_project_id) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 0e1012b2de65..cf653f12d75a 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1157,6 +1157,14 @@ def _add_argparse_args(cls, parser): action='store_true', help='Throttling counter in GcsIO is enabled by default. Set ' '--no_gcsio_throttling_counter to avoid it.') + parser.add_argument( + '--quota_project_id', + default=None, + help='GCP project ID to use for quota and billing purposes. ' + 'If not specified, the project associated with the credentials ' + 'will be used for quota. This is useful when running pipelines ' + 'that access resources in a different project than the one ' + 'associated with the credentials.') parser.add_argument( '--enable_gcsio_blob_generation', default=False,