Skip to content

Commit cf4f63d

Browse files
authored
Merge branch 'main' into multiplexed-sessions
2 parents 70c7967 + 6ca9b43 commit cf4f63d

19 files changed

+1165
-231
lines changed

.github/.OwlBot.lock.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
# limitations under the License.
1414
docker:
1515
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
16-
digest: sha256:5581906b957284864632cde4e9c51d1cc66b0094990b27e689132fe5cd036046
17-
# created: 2025-03-05
16+
digest: sha256:25de45b58e52021d3a24a6273964371a97a4efeefe6ad3845a64e697c63b6447
17+
# created: 2025-04-14T14:34:43.260858345Z

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import TYPE_CHECKING
14+
from typing import TYPE_CHECKING, Union
15+
from google.cloud.spanner_v1 import TransactionOptions
1516

1617
if TYPE_CHECKING:
1718
from google.cloud.spanner_dbapi.cursor import Cursor
@@ -58,7 +59,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
5859
connection.commit()
5960
return None
6061
if statement_type == ClientSideStatementType.BEGIN:
61-
connection.begin()
62+
connection.begin(isolation_level=_get_isolation_level(parsed_statement))
6263
return None
6364
if statement_type == ClientSideStatementType.ROLLBACK:
6465
connection.rollback()
@@ -121,3 +122,19 @@ def _get_streamed_result_set(column_name, type_code, column_values):
121122
column_values_pb.append(_make_value_pb(column_value))
122123
result_set.values.extend(column_values_pb)
123124
return StreamedResultSet(iter([result_set]))
125+
126+
127+
def _get_isolation_level(
128+
statement: ParsedStatement,
129+
) -> Union[TransactionOptions.IsolationLevel, None]:
130+
if (
131+
statement.client_side_statement_params is None
132+
or len(statement.client_side_statement_params) == 0
133+
):
134+
return None
135+
level = statement.client_side_statement_params[0]
136+
if not isinstance(level, str) or level == "":
137+
return None
138+
# Replace (duplicate) whitespaces in the string with an underscore.
139+
level = "_".join(level.split()).upper()
140+
return TransactionOptions.IsolationLevel[level]

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
Statement,
2222
)
2323

24-
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
24+
RE_BEGIN = re.compile(
25+
r"^\s*(?:BEGIN|START)(?:\s+TRANSACTION)?(?:\s+ISOLATION\s+LEVEL\s+(REPEATABLE\s+READ|SERIALIZABLE))?\s*$",
26+
re.IGNORECASE,
27+
)
2528
RE_COMMIT = re.compile(r"^\s*(COMMIT)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
2629
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
2730
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
@@ -68,6 +71,10 @@ def parse_stmt(query):
6871
elif RE_START_BATCH_DML.match(query):
6972
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
7073
elif RE_BEGIN.match(query):
74+
match = re.search(RE_BEGIN, query)
75+
isolation_level = match.group(1)
76+
if isolation_level is not None:
77+
client_side_statement_params.append(isolation_level)
7178
client_side_statement_type = ClientSideStatementType.BEGIN
7279
elif RE_RUN_BATCH.match(query):
7380
client_side_statement_type = ClientSideStatementType.RUN_BATCH

google/cloud/spanner_v1/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from .types.type import Type
6464
from .types.type import TypeAnnotationCode
6565
from .types.type import TypeCode
66-
from .data_types import JsonObject
66+
from .data_types import JsonObject, Interval
6767
from .transaction import BatchTransactionId, DefaultTransactionOptions
6868

6969
from google.cloud.spanner_v1 import param_types
@@ -145,6 +145,7 @@
145145
"TypeCode",
146146
# Custom spanner related data types
147147
"JsonObject",
148+
"Interval",
148149
# google.cloud.spanner_v1.services
149150
"SpannerClient",
150151
"SpannerAsyncClient",

google/cloud/spanner_v1/_helpers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from google.cloud._helpers import _date_from_iso8601_date
3232
from google.cloud.spanner_v1 import TypeCode
3333
from google.cloud.spanner_v1 import ExecuteSqlRequest
34-
from google.cloud.spanner_v1 import JsonObject
34+
from google.cloud.spanner_v1 import JsonObject, Interval
3535
from google.cloud.spanner_v1 import TransactionOptions
3636
from google.cloud.spanner_v1.request_id_header import with_request_id
3737
from google.rpc.error_details_pb2 import RetryInfo
@@ -251,6 +251,8 @@ def _make_value_pb(value):
251251
return Value(null_value="NULL_VALUE")
252252
else:
253253
return Value(string_value=base64.b64encode(value))
254+
if isinstance(value, Interval):
255+
return Value(string_value=str(value))
254256

255257
raise ValueError("Unknown type: %s" % (value,))
256258

@@ -367,6 +369,8 @@ def _get_type_decoder(field_type, field_name, column_info=None):
367369
for item_field in field_type.struct_type.fields
368370
]
369371
return lambda value_pb: _parse_struct(value_pb, element_decoders)
372+
elif type_code == TypeCode.INTERVAL:
373+
return _parse_interval
370374
else:
371375
raise ValueError("Unknown type: %s" % (field_type,))
372376

@@ -473,6 +477,13 @@ def _parse_nullable(value_pb, decoder):
473477
return decoder(value_pb)
474478

475479

480+
def _parse_interval(value_pb):
481+
"""Parse a Value protobuf containing an interval."""
482+
if hasattr(value_pb, "string_value"):
483+
return Interval.from_str(value_pb.string_value)
484+
return Interval.from_str(value_pb)
485+
486+
476487
class _SessionWrapper(object):
477488
"""Base class for objects wrapping a session.
478489

google/cloud/spanner_v1/data_types.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
import json
1818
import types
19-
19+
import re
20+
from dataclasses import dataclass
2021
from google.protobuf.message import Message
2122
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
2223

@@ -97,6 +98,152 @@ def serialize(self):
9798
return json.dumps(self, sort_keys=True, separators=(",", ":"))
9899

99100

101+
@dataclass
102+
class Interval:
103+
"""Represents a Spanner INTERVAL type.
104+
105+
An interval is a combination of months, days and nanoseconds.
106+
Internally, Spanner supports Interval value with the following range of individual fields:
107+
months: [-120000, 120000]
108+
days: [-3660000, 3660000]
109+
nanoseconds: [-316224000000000000000, 316224000000000000000]
110+
"""
111+
112+
months: int = 0
113+
days: int = 0
114+
nanos: int = 0
115+
116+
def __str__(self) -> str:
117+
"""Returns the ISO8601 duration format string representation."""
118+
result = ["P"]
119+
120+
# Handle years and months
121+
if self.months:
122+
is_negative = self.months < 0
123+
abs_months = abs(self.months)
124+
years, months = divmod(abs_months, 12)
125+
if years:
126+
result.append(f"{'-' if is_negative else ''}{years}Y")
127+
if months:
128+
result.append(f"{'-' if is_negative else ''}{months}M")
129+
130+
# Handle days
131+
if self.days:
132+
result.append(f"{self.days}D")
133+
134+
# Handle time components
135+
if self.nanos:
136+
result.append("T")
137+
nanos = abs(self.nanos)
138+
is_negative = self.nanos < 0
139+
140+
# Convert to hours, minutes, seconds
141+
nanos_per_hour = 3600000000000
142+
hours, nanos = divmod(nanos, nanos_per_hour)
143+
if hours:
144+
if is_negative:
145+
result.append("-")
146+
result.append(f"{hours}H")
147+
148+
nanos_per_minute = 60000000000
149+
minutes, nanos = divmod(nanos, nanos_per_minute)
150+
if minutes:
151+
if is_negative:
152+
result.append("-")
153+
result.append(f"{minutes}M")
154+
155+
nanos_per_second = 1000000000
156+
seconds, nanos_fraction = divmod(nanos, nanos_per_second)
157+
158+
if seconds or nanos_fraction:
159+
if is_negative:
160+
result.append("-")
161+
if seconds:
162+
result.append(str(seconds))
163+
elif nanos_fraction:
164+
result.append("0")
165+
166+
if nanos_fraction:
167+
nano_str = f"{nanos_fraction:09d}"
168+
trimmed = nano_str.rstrip("0")
169+
if len(trimmed) <= 3:
170+
while len(trimmed) < 3:
171+
trimmed += "0"
172+
elif len(trimmed) <= 6:
173+
while len(trimmed) < 6:
174+
trimmed += "0"
175+
else:
176+
while len(trimmed) < 9:
177+
trimmed += "0"
178+
result.append(f".{trimmed}")
179+
result.append("S")
180+
181+
if len(result) == 1:
182+
result.append("0Y") # Special case for zero interval
183+
184+
return "".join(result)
185+
186+
@classmethod
187+
def from_str(cls, s: str) -> "Interval":
188+
"""Parse an ISO8601 duration format string into an Interval."""
189+
pattern = r"^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$"
190+
match = re.match(pattern, s)
191+
if not match or len(s) == 1:
192+
raise ValueError(f"Invalid interval format: {s}")
193+
194+
parts = match.groups()
195+
if not any(parts[:3]) and not parts[3]:
196+
raise ValueError(
197+
f"Invalid interval format: at least one component (Y/M/D/H/M/S) is required: {s}"
198+
)
199+
200+
if parts[3] == "T" and not any(parts[4:7]):
201+
raise ValueError(
202+
f"Invalid interval format: time designator 'T' present but no time components specified: {s}"
203+
)
204+
205+
def parse_num(s: str, suffix: str) -> int:
206+
if not s:
207+
return 0
208+
return int(s.rstrip(suffix))
209+
210+
years = parse_num(parts[0], "Y")
211+
months = parse_num(parts[1], "M")
212+
total_months = years * 12 + months
213+
214+
days = parse_num(parts[2], "D")
215+
216+
nanos = 0
217+
if parts[3]: # Has time component
218+
# Convert hours to nanoseconds
219+
hours = parse_num(parts[4], "H")
220+
nanos += hours * 3600000000000
221+
222+
# Convert minutes to nanoseconds
223+
minutes = parse_num(parts[5], "M")
224+
nanos += minutes * 60000000000
225+
226+
# Handle seconds and fractional seconds
227+
if parts[6]:
228+
seconds = parts[6].rstrip("S")
229+
if "," in seconds:
230+
seconds = seconds.replace(",", ".")
231+
232+
if "." in seconds:
233+
sec_parts = seconds.split(".")
234+
whole_seconds = sec_parts[0] if sec_parts[0] else "0"
235+
nanos += int(whole_seconds) * 1000000000
236+
frac = sec_parts[1][:9].ljust(9, "0")
237+
frac_nanos = int(frac)
238+
if seconds.startswith("-"):
239+
frac_nanos = -frac_nanos
240+
nanos += frac_nanos
241+
else:
242+
nanos += int(seconds) * 1000000000
243+
244+
return cls(months=total_months, days=days, nanos=nanos)
245+
246+
100247
def _proto_message(bytes_val, proto_message_object):
101248
"""Helper for :func:`get_proto_message`.
102249
parses serialized protocol buffer bytes data into proto message.

google/cloud/spanner_v1/param_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
PG_NUMERIC = Type(code=TypeCode.NUMERIC, type_annotation=TypeAnnotationCode.PG_NUMERIC)
3737
PG_JSONB = Type(code=TypeCode.JSON, type_annotation=TypeAnnotationCode.PG_JSONB)
3838
PG_OID = Type(code=TypeCode.INT64, type_annotation=TypeAnnotationCode.PG_OID)
39+
INTERVAL = Type(code=TypeCode.INTERVAL)
3940

4041

4142
def Array(element_type):

google/cloud/spanner_v1/streamed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ def _merge_struct(lhs, rhs, type_):
391391
TypeCode.NUMERIC: _merge_string,
392392
TypeCode.JSON: _merge_string,
393393
TypeCode.PROTO: _merge_string,
394+
TypeCode.INTERVAL: _merge_string,
394395
TypeCode.ENUM: _merge_string,
395396
}
396397

0 commit comments

Comments
 (0)