Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@

## New Features / Improvements

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added support for setting quota project ID in BigQuery read operations via `--quota_project_id` pipeline option or `quota_project_id` parameter in ReadFromBigQuery transform (Python) ([#37431](https://github.com/apache/beam/issues/37431)).

## Breaking Changes

Expand Down
41 changes: 41 additions & 0 deletions sdks/python/apache_beam/internal/gcp/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,47 @@ def get_service_credentials(pipeline_options):
return _Credentials.get_service_credentials(pipeline_options)


def with_quota_project(credentials, quota_project_id):
  """For internal use only; no backwards-compatibility guarantees.

  Return credentials with a quota project applied, when supported.

  The quota project determines which GCP project is billed for API
  requests, independently of the project that owns the service account
  or the data being accessed.

  Args:
    credentials: The credentials object (either an
      _ApitoolsCredentialsAdapter wrapper or a bare google.auth
      credentials object). May be None.
    quota_project_id: The GCP project ID to use for quota and billing,
      or None to leave the credentials untouched.

  Returns:
    Credentials with the quota project applied; the original credentials
    when nothing needs to change (None inputs) or when the credentials
    type does not support quota projects.
  """
  # Nothing to do without both credentials and a quota project.
  if credentials is None or quota_project_id is None:
    return credentials

  # Unwrap the adapter, if present, to reach the google-auth credentials.
  is_wrapped = hasattr(credentials, 'get_google_auth_credentials')
  underlying = (
      credentials.get_google_auth_credentials() if is_wrapped else credentials)

  if not hasattr(underlying, 'with_quota_project'):
    _LOGGER.warning(
        'Credentials of type %s do not support quota project. '
        'The quota_project_id parameter will be ignored.',
        type(underlying).__name__)
    return credentials

  updated = underlying.with_quota_project(quota_project_id)
  # Preserve the caller's wrapping: re-wrap iff the input was wrapped.
  return _ApitoolsCredentialsAdapter(updated) if is_wrapped else updated


if _GOOGLE_AUTH_AVAILABLE:

class _ApitoolsCredentialsAdapter:
Expand Down
60 changes: 60 additions & 0 deletions sdks/python/apache_beam/internal/gcp/auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,65 @@ def raise_(scopes=None):
auth._LOGGER.removeHandler(loggerHandler)


@unittest.skipIf(gauth is None, 'Google Auth dependencies are not installed')
class WithQuotaProjectTest(unittest.TestCase):
  """Tests for the with_quota_project helper."""
  def test_with_quota_project_returns_credentials_unchanged_when_none(self):
    """None credentials pass through untouched."""
    self.assertIsNone(auth.with_quota_project(None, 'my-project'))

  def test_with_quota_project_returns_credentials_unchanged_when_no_quota(self):
    """Credentials pass through untouched when no quota project is given."""
    creds = mock.MagicMock()

    result = auth.with_quota_project(creds, None)

    self.assertEqual(result, creds)
    creds.with_quota_project.assert_not_called()

  @mock.patch('apache_beam.internal.gcp.auth._ApitoolsCredentialsAdapter')
  def test_with_quota_project_applies_quota_to_wrapped_credentials(
      self, adapter_cls):
    """Wrapped credentials are unwrapped, updated, and re-wrapped."""
    inner = mock.MagicMock()
    updated = mock.MagicMock()
    inner.with_quota_project.return_value = updated

    wrapper = mock.MagicMock()
    wrapper.get_google_auth_credentials.return_value = inner

    rewrapped = mock.MagicMock()
    adapter_cls.return_value = rewrapped

    result = auth.with_quota_project(wrapper, 'my-billing-project')

    inner.with_quota_project.assert_called_once_with('my-billing-project')
    # The updated google-auth credentials must be wrapped in a new adapter.
    adapter_cls.assert_called_once_with(updated)
    self.assertEqual(result, rewrapped)

  def test_with_quota_project_applies_quota_to_direct_credentials(self):
    """Unwrapped credentials get the quota project applied directly."""
    creds = mock.MagicMock(spec=['with_quota_project'])
    updated = mock.MagicMock()
    creds.with_quota_project.return_value = updated

    result = auth.with_quota_project(creds, 'my-billing-project')

    creds.with_quota_project.assert_called_once_with('my-billing-project')
    self.assertEqual(result, updated)

  def test_with_quota_project_returns_original_when_not_supported(self):
    """Credentials lacking with_quota_project are returned untouched."""
    # spec=[] produces a mock with no with_quota_project attribute.
    creds = mock.MagicMock(spec=[])

    result = auth.with_quota_project(creds, 'my-billing-project')

    self.assertEqual(result, creds)


# Allow running this test module directly (outside the test runner).
if __name__ == '__main__':
  unittest.main()
86 changes: 75 additions & 11 deletions sdks/python/apache_beam/io/gcp/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def chain_after(result):
import apache_beam as beam
from apache_beam import coders
from apache_beam import pvalue
from apache_beam.internal.gcp import auth
from apache_beam.internal.gcp.json_value import from_json_value
from apache_beam.internal.gcp.json_value import to_json_value
from apache_beam.io import range_trackers
Expand Down Expand Up @@ -662,7 +663,8 @@ def __init__(
step_name=None,
unique_id=None,
temp_dataset=None,
query_priority=BigQueryQueryPriority.BATCH):
query_priority=BigQueryQueryPriority.BATCH,
quota_project_id=None):
if table is not None and query is not None:
raise ValueError(
'Both a BigQuery table and a query were specified.'
Expand Down Expand Up @@ -696,6 +698,7 @@ def __init__(
self.use_json_exports = use_json_exports
self.temp_dataset = temp_dataset
self.query_priority = query_priority
self.quota_project_id = quota_project_id
self._job_name = job_name or 'BQ_EXPORT_JOB'
self._step_name = step_name
self._source_uuid = unique_id
Expand All @@ -715,6 +718,7 @@ def display_data(self):
'use_legacy_sql': self.use_legacy_sql,
'bigquery_job_labels': json.dumps(self.bigquery_job_labels),
'export_file_format': export_format,
'quota_project_id': self._get_quota_project_id() or '',
'launchesBigQueryJobs': DisplayDataItem(
True, label="This Dataflow job launches bigquery jobs."),
}
Expand Down Expand Up @@ -782,6 +786,18 @@ def _get_project(self):
project = self.project
return project

def _get_quota_project_id(self):
  """Returns the quota project ID for API calls, or None.

  The explicit quota_project_id constructor parameter wins; otherwise the
  quota_project_id configured on GoogleCloudOptions is used when pipeline
  options are available.
  """
  explicit = self.quota_project_id
  if explicit:
    return explicit
  pipeline_opts = self.options
  if pipeline_opts is None:
    return None
  return pipeline_opts.view_as(GoogleCloudOptions).quota_project_id

def _create_source(self, path, coder):
if not self.use_json_exports:
return create_avro_source(path, validate=self.validate)
Expand All @@ -799,7 +815,8 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None):
bq = bigquery_tools.BigQueryWrapper(
temp_dataset_id=(
self.temp_dataset.datasetId if self.temp_dataset else None),
client=bigquery_tools.BigQueryWrapper._bigquery_client(self.options))
client=bigquery_tools.BigQueryWrapper._bigquery_client(self.options),
quota_project_id=self._get_quota_project_id())

if self.query is not None:
self._setup_temporary_dataset(bq)
Expand Down Expand Up @@ -932,6 +949,31 @@ def _export_files(self, bq):
return table.schema, metadata_list


def _create_bq_storage_client(quota_project_id=None):
  """Create a BigQueryReadClient with an optional quota project.

  Args:
    quota_project_id: Optional GCP project ID to use for quota and billing.

  Returns:
    A BigQueryReadClient instance. When quota_project_id is set and the
    application-default credentials support it, the client bills quota to
    that project; otherwise a default client is returned.
  """
  if quota_project_id:
    # Imported lazily: only needed on the quota-project path.
    import google.auth
    from google.auth import exceptions as auth_exceptions
    try:
      default_creds, _ = google.auth.default()
      quota_creds = auth.with_quota_project(default_creds, quota_project_id)
      return bq_storage.BigQueryReadClient(credentials=quota_creds)
    except (auth_exceptions.DefaultCredentialsError, AttributeError) as e:
      # Best-effort: fall back to the default client rather than failing.
      _LOGGER.warning(
          'Failed to apply quota project %s to BigQuery Storage client: %s. '
          'Falling back to default client.',
          quota_project_id,
          e)
  return bq_storage.BigQueryReadClient()


class _CustomBigQueryStorageSource(BoundedSource):
"""A base class for BoundedSource implementations which read from BigQuery
using the BigQuery Storage API.
Expand Down Expand Up @@ -989,7 +1031,8 @@ def __init__(
temp_dataset: Optional[DatasetReference] = None,
temp_table: Optional[TableReference] = None,
use_native_datetime: Optional[bool] = False,
timeout: Optional[float] = None):
timeout: Optional[float] = None,
quota_project_id: Optional[str] = None):

if table is not None and query is not None:
raise ValueError(
Expand Down Expand Up @@ -1028,6 +1071,7 @@ def __init__(
self._job_name = job_name or 'BQ_DIRECT_READ_JOB'
self._step_name = step_name
self._source_uuid = unique_id
self.quota_project_id = quota_project_id

def _get_project(self):
"""Returns the project that queries and exports will be billed to."""
Expand All @@ -1039,6 +1083,18 @@ def _get_project(self):
return project
return self.project

def _get_quota_project_id(self):
  """Returns the quota project ID for API calls, or None.

  The explicit quota_project_id constructor parameter wins; otherwise the
  quota_project_id configured on GoogleCloudOptions is used when pipeline
  options are available.
  """
  explicit = self.quota_project_id
  if explicit:
    return explicit
  pipeline_opts = self.pipeline_options
  if pipeline_opts is None:
    return None
  return pipeline_opts.view_as(GoogleCloudOptions).quota_project_id

def _get_parent_project(self):
"""Returns the project that will be billed."""
if self.temp_table:
Expand Down Expand Up @@ -1168,7 +1224,8 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None):
bq = bigquery_tools.BigQueryWrapper(
temp_table_ref=(self.temp_table if self.temp_table else None),
client=bigquery_tools.BigQueryWrapper._bigquery_client(
self.pipeline_options))
self.pipeline_options),
quota_project_id=self._get_quota_project_id())

if self.query is not None:
self._setup_temporary_dataset(bq)
Expand Down Expand Up @@ -1201,7 +1258,7 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None):
if self.row_restriction is not None:
requested_session.read_options.row_restriction = self.row_restriction

storage_client = bq_storage.BigQueryReadClient()
storage_client = _create_bq_storage_client(self._get_quota_project_id())
stream_count = 0
if desired_bundle_size > 0:
table_size = self._get_table_size(bq, self.table_reference)
Expand Down Expand Up @@ -1232,8 +1289,10 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None):

self.split_result = [
_CustomBigQueryStorageStreamSource(
stream.name, self.use_native_datetime, self.timeout)
for stream in read_session.streams
stream.name,
self.use_native_datetime,
self.timeout,
self._get_quota_project_id()) for stream in read_session.streams
]

for source in self.split_result:
Expand Down Expand Up @@ -1267,10 +1326,12 @@ def __init__(
self,
read_stream_name: str,
use_native_datetime: Optional[bool] = True,
timeout: Optional[float] = None):
timeout: Optional[float] = None,
quota_project_id: Optional[str] = None):
self.read_stream_name = read_stream_name
self.use_native_datetime = use_native_datetime
self.timeout = timeout
self.quota_project_id = quota_project_id

def display_data(self):
return {
Expand All @@ -1293,7 +1354,10 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None):
return SourceBundle(
weight=1.0,
source=_CustomBigQueryStorageStreamSource(
self.read_stream_name, self.use_native_datetime),
self.read_stream_name,
self.use_native_datetime,
self.timeout,
self.quota_project_id),
start_position=None,
stop_position=None)

Expand Down Expand Up @@ -1329,7 +1393,7 @@ def retry_delay_callback(delay):

def read_arrow(self):

storage_client = bq_storage.BigQueryReadClient()
storage_client = _create_bq_storage_client(self.quota_project_id)
read_rows_kwargs = {'retry_delay_callback': self.retry_delay_callback}
if self.timeout is not None:
read_rows_kwargs['timeout'] = self.timeout
Expand All @@ -1348,7 +1412,7 @@ def read_arrow(self):
yield py_row

def read_avro(self):
storage_client = bq_storage.BigQueryReadClient()
storage_client = _create_bq_storage_client(self.quota_project_id)
read_rows_kwargs = {'retry_delay_callback': self.retry_delay_callback}
if self.timeout is not None:
read_rows_kwargs['timeout'] = self.timeout
Expand Down
Loading
Loading