Skip to content

Commit c4496fd

Browse files
authored
feat: support request and transaction tags (#558)
* feat: support transaction and request tags in dbapi Adds support for setting transaction tags and request tags in dbapi. This makes these options available to frameworks that depend on dbapi, like SQLAlchemy and Django. Towards #525 * test: add test for transaction tags * test: fix test cases
1 parent 39fb556 commit c4496fd

File tree

4 files changed

+180
-1
lines changed

4 files changed

+180
-1
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,14 @@ def pre_exec(self):
180180
if priority is not None:
181181
self._dbapi_connection.connection.request_priority = priority
182182

183+
transaction_tag = self.execution_options.get("transaction_tag")
184+
if transaction_tag:
185+
self._dbapi_connection.connection.transaction_tag = transaction_tag
186+
187+
request_tag = self.execution_options.get("request_tag")
188+
if request_tag:
189+
self.cursor.request_tag = request_tag
190+
183191
def fire_sequence(self, seq, type_):
184192
"""Builds a statement for fetching next value of the sequence."""
185193
return self._execute_scalar(
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2025 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy import String, BigInteger
16+
from sqlalchemy.orm import DeclarativeBase
17+
from sqlalchemy.orm import Mapped
18+
from sqlalchemy.orm import mapped_column
19+
20+
21+
class Base(DeclarativeBase):
22+
pass
23+
24+
25+
class Singer(Base):
26+
__tablename__ = "singers"
27+
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
28+
name: Mapped[str] = mapped_column(String)

test/mockserver_tests/test_basics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ def test_sqlalchemy_select1(self):
7474
with engine.connect().execution_options(
7575
isolation_level="AUTOCOMMIT"
7676
) as connection:
77-
results = connection.execute(select(1)).fetchall()
77+
results = connection.execute(
78+
select(1).execution_options(request_tag="my-tag")
79+
).fetchall()
7880
self.verify_select1(results)
81+
request: ExecuteSqlRequest = self.spanner_service.requests[1]
82+
eq_("my-tag", request.request_options.request_tag)
7983

8084
def test_sqlalchemy_select_now(self):
8185
now = datetime.datetime.now(datetime.UTC)

test/mockserver_tests/test_tags.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy import create_engine, select
16+
from sqlalchemy.orm import Session
17+
from sqlalchemy.testing import eq_, is_instance_of
18+
from google.cloud.spanner_v1 import (
19+
FixedSizePool,
20+
BatchCreateSessionsRequest,
21+
ExecuteSqlRequest,
22+
BeginTransactionRequest,
23+
CommitRequest,
24+
)
25+
from test.mockserver_tests.mock_server_test_base import (
26+
MockServerTestBase,
27+
add_update_count,
28+
)
29+
from test.mockserver_tests.mock_server_test_base import add_result
30+
import google.cloud.spanner_v1.types.type as spanner_type
31+
import google.cloud.spanner_v1.types.result_set as result_set
32+
33+
34+
class TestStaleReads(MockServerTestBase):
35+
def test_request_tag(self):
36+
from test.mockserver_tests.tags_model import Singer
37+
38+
add_singer_query_result("SELECT singers.id, singers.name \n" + "FROM singers")
39+
engine = create_engine(
40+
"spanner:///projects/p/instances/i/databases/d",
41+
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
42+
)
43+
44+
with Session(engine.execution_options(read_only=True)) as session:
45+
# Execute two queries in a read-only transaction.
46+
session.scalars(
47+
select(Singer).execution_options(request_tag="my-tag-1")
48+
).all()
49+
session.scalars(
50+
select(Singer).execution_options(request_tag="my-tag-2")
51+
).all()
52+
53+
# Verify the requests that we got.
54+
requests = self.spanner_service.requests
55+
eq_(4, len(requests))
56+
is_instance_of(requests[0], BatchCreateSessionsRequest)
57+
is_instance_of(requests[1], BeginTransactionRequest)
58+
is_instance_of(requests[2], ExecuteSqlRequest)
59+
is_instance_of(requests[3], ExecuteSqlRequest)
60+
# Verify that we got a request tag for the queries.
61+
eq_("my-tag-1", requests[2].request_options.request_tag)
62+
eq_("my-tag-2", requests[3].request_options.request_tag)
63+
64+
def test_transaction_tag(self):
65+
from test.mockserver_tests.tags_model import Singer
66+
67+
add_singer_query_result("SELECT singers.id, singers.name\n" + "FROM singers")
68+
add_update_count("INSERT INTO singers (id, name) VALUES (@a0, @a1)", 1)
69+
engine = create_engine(
70+
"spanner:///projects/p/instances/i/databases/d",
71+
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
72+
)
73+
74+
with Session(
75+
engine.execution_options(transaction_tag="my-transaction-tag")
76+
) as session:
77+
# Execute a query and an insert statement in a read/write transaction.
78+
session.scalars(
79+
select(Singer).execution_options(request_tag="my-tag-1")
80+
).all()
81+
session.add(Singer(id=1, name="Some Singer"))
82+
session.commit()
83+
84+
# Verify the requests that we got.
85+
requests = self.spanner_service.requests
86+
eq_(5, len(requests))
87+
is_instance_of(requests[0], BatchCreateSessionsRequest)
88+
is_instance_of(requests[1], BeginTransactionRequest)
89+
is_instance_of(requests[2], ExecuteSqlRequest)
90+
is_instance_of(requests[3], ExecuteSqlRequest)
91+
is_instance_of(requests[4], CommitRequest)
92+
for request in requests[2:]:
93+
eq_("my-transaction-tag", request.request_options.transaction_tag)
94+
95+
96+
def add_singer_query_result(sql: str):
97+
result = result_set.ResultSet(
98+
dict(
99+
metadata=result_set.ResultSetMetadata(
100+
dict(
101+
row_type=spanner_type.StructType(
102+
dict(
103+
fields=[
104+
spanner_type.StructType.Field(
105+
dict(
106+
name="singers_id",
107+
type=spanner_type.Type(
108+
dict(code=spanner_type.TypeCode.INT64)
109+
),
110+
)
111+
),
112+
spanner_type.StructType.Field(
113+
dict(
114+
name="singers_name",
115+
type=spanner_type.Type(
116+
dict(code=spanner_type.TypeCode.STRING)
117+
),
118+
)
119+
),
120+
]
121+
)
122+
)
123+
)
124+
),
125+
)
126+
)
127+
result.rows.extend(
128+
[
129+
(
130+
"1",
131+
"Jane Doe",
132+
),
133+
(
134+
"2",
135+
"John Doe",
136+
),
137+
]
138+
)
139+
add_result(sql, result)

0 commit comments

Comments
 (0)