Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions datastore/gcloud/aio/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class Datastore:
_api_root: str
_api_is_dev: bool

Timeout = Union[int, float]

def __init__(
self, project: Optional[str] = None,
service_file: Optional[Union[str, IO[AnyStr]]] = None,
Expand Down Expand Up @@ -159,7 +161,7 @@ def make_mutation(
async def allocateIds(
self, keys: List[Key],
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> List[Key]:
project = await self.project()
url = f'{self._api_root}/projects/{project}:allocateIds'
Expand Down Expand Up @@ -187,7 +189,7 @@ async def allocateIds(
# TODO: support readwrite vs readonly transaction types
async def beginTransaction(
self, session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> str:
project = await self.project()
url = f'{self._api_root}/projects/{project}:beginTransaction'
Expand All @@ -210,7 +212,7 @@ async def commit(
transaction: Optional[str] = None,
mode: Mode = Mode.TRANSACTIONAL,
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> Dict[str, Any]:
project = await self.project()
url = f'{self._api_root}/projects/{project}:commit'
Expand Down Expand Up @@ -249,7 +251,7 @@ async def export(
namespaces: Optional[List[str]] = None,
labels: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> DatastoreOperation:
project = await self.project()
url = f'{self._api_root}/projects/{project}:export'
Expand Down Expand Up @@ -282,7 +284,7 @@ async def export(
async def get_datastore_operation(
self, name: str,
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> DatastoreOperation:
url = f'{self._api_root}/{name}'

Expand All @@ -297,19 +299,21 @@ async def get_datastore_operation(

return self.datastore_operation_kind.from_repr(data)

# pylint: disable=too-many-locals
# https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/lookup
async def lookup(
self, keys: List[Key],
transaction: Optional[str] = None,
newTransaction: Optional[TransactionOptions] = None,
consistency: Consistency = Consistency.STRONG,
session: Optional[Session] = None, timeout: int = 10,
read_time: Optional[str] = None,
session: Optional[Session] = None, timeout: Timeout = 10,
) -> LookUpResult:
project = await self.project()
url = f'{self._api_root}/projects/{project}:lookup'

read_options = self._build_read_options(
consistency, newTransaction, transaction)
consistency, newTransaction, transaction, read_time)

payload = json.dumps({
'keys': [k.to_repr() for k in keys],
Expand Down Expand Up @@ -350,27 +354,35 @@ def _build_lookup_result(self, data: Dict[str, Any]) -> LookUpResult:
if 'transaction' in data:
new_transaction: str = data['transaction']
result['transaction'] = new_transaction
if 'readTime' in data:
read_time: str = data['readTime']
result['readTime'] = read_time
return result

# https://cloud.google.com/datastore/docs/reference/data/rest/v1/ReadOptions
def _build_read_options(self,
consistency: Consistency,
newTransaction: Optional[TransactionOptions],
transaction: Optional[str]) -> Dict[str, Any]:
transaction: Optional[str],
read_time: Optional[str],
) -> Dict[str, Any]:
# TODO: expose ReadOptions directly to users
if transaction:
return {'transaction': transaction}

if newTransaction:
return {'newTransaction': newTransaction.to_repr()}

if read_time:
return {'readTime': read_time}

return {'readConsistency': consistency.value}

# https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/reserveIds
async def reserveIds(
self, keys: List[Key], database_id: str = '',
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> None:
project = await self.project()
url = f'{self._api_root}/projects/{project}:reserveIds'
Expand All @@ -393,7 +405,7 @@ async def reserveIds(
async def rollback(
self, transaction: str,
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> None:
project = await self.project()
url = f'{self._api_root}/projects/{project}:rollback'
Expand All @@ -417,26 +429,26 @@ async def runQuery(
self, query: BaseQuery,
explain_options: Optional[ExplainOptions] = None,
transaction: Optional[str] = None,
newTransaction: Optional[TransactionOptions] = None,
consistency: Consistency = Consistency.EVENTUAL,
read_time: Optional[str] = None,
session: Optional[Session] = None,
timeout: int = 10,
timeout: Timeout = 10,
) -> QueryResult:

project = await self.project()
url = f'{self._api_root}/projects/{project}:runQuery'

if transaction:
options = {'transaction': transaction}
else:
options = {'readConsistency': consistency.value}
read_options = self._build_read_options(
consistency, newTransaction, transaction, read_time)

payload_dict = {
'partitionId': {
'projectId': project,
'namespaceId': self.namespace,
},
query.json_key: query.to_repr(),
'readOptions': options,
'readOptions': read_options,
}

if explain_options:
Expand Down
12 changes: 10 additions & 2 deletions datastore/gcloud/aio/datastore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __eq__(self, other: Any) -> bool:


class QueryResultBatch:
# pylint: disable=too-many-instance-attributes
entity_result_kind = EntityResult

def __init__(
Expand All @@ -203,6 +204,7 @@ def __init__(
more_results: MoreResultsType = MoreResultsType.UNSPECIFIED,
skipped_cursor: str = '', skipped_results: int = 0,
snapshot_version: str = '',
read_time: Optional[str] = None,
) -> None:
self.end_cursor = end_cursor

Expand All @@ -212,6 +214,7 @@ def __init__(
self.skipped_cursor = skipped_cursor
self.skipped_results = skipped_results
self.snapshot_version = snapshot_version
self.read_time = read_time

def __eq__(self, other: Any) -> bool:
if not isinstance(other, QueryResultBatch):
Expand All @@ -224,7 +227,8 @@ def __eq__(self, other: Any) -> bool:
and self.more_results == other.more_results
and self.skipped_cursor == other.skipped_cursor
and self.skipped_results == other.skipped_results
and self.snapshot_version == other.snapshot_version,
and self.snapshot_version == other.snapshot_version
and self.read_time == other.read_time,
)

def __repr__(self) -> str:
Expand All @@ -242,12 +246,15 @@ def from_repr(cls, data: Dict[str, Any]) -> 'QueryResultBatch':
skipped_cursor = data.get('skippedCursor', '')
skipped_results = data.get('skippedResults', 0)
snapshot_version = data.get('snapshotVersion', '')
read_time = data.get('readTime')

return cls(
end_cursor, entity_result_type=entity_result_type,
entity_results=entity_results, more_results=more_results,
skipped_cursor=skipped_cursor,
skipped_results=skipped_results,
snapshot_version=snapshot_version,
read_time=read_time,
)

def to_repr(self) -> Dict[str, Any]:
Expand All @@ -262,7 +269,8 @@ def to_repr(self) -> Dict[str, Any]:
data['skippedCursor'] = self.skipped_cursor
if self.snapshot_version:
data['snapshotVersion'] = self.snapshot_version

if self.read_time:
data['readTime'] = self.read_time
return data


Expand Down
93 changes: 93 additions & 0 deletions datastore/tests/integration/smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import uuid

import pytest
Expand Down Expand Up @@ -657,3 +658,95 @@ async def test_analyze_query_explain(
finally:
for key in test_entities:
await ds.delete(key, session=s)


@pytest.mark.asyncio
async def test_lookup_with_read_time(
        creds: str, kind: str, project: str) -> None:
    """Verify that Datastore.lookup honours the read_time option."""
    entity_name = f'test_read_time_{uuid.uuid4()}'
    key = Key(project, [PathElement(kind, name=entity_name)])

    async with Session() as s:
        ds = Datastore(project=project, service_file=creds, session=s)

        # Record a timestamp strictly before the entity exists, then insert.
        before_insert = datetime.datetime.now(datetime.timezone.utc)
        await ds.insert(
            key, {'value': entity_name, 'timestamp': 'after'}, session=s)

        # A plain lookup (no read_time) must find the fresh entity.
        plain = await ds.lookup([key], session=s)
        assert len(plain['found']) == 1
        assert plain['found'][0].entity.properties['value'] == entity_name
        assert isinstance(plain['readTime'], str)

        # A lookup pinned to "now" must still see the entity.
        now = datetime.datetime.now(datetime.timezone.utc)
        now_str = now.isoformat().replace('+00:00', 'Z')
        pinned = await ds.lookup([key], read_time=now_str, session=s)
        assert len(pinned.get('found', [])) == 1
        assert isinstance(pinned['readTime'], str)

        # A lookup pinned before the insert must report the key as missing.
        past = before_insert - datetime.timedelta(seconds=10)
        past_str = past.isoformat().replace('+00:00', 'Z')
        stale = await ds.lookup([key], read_time=past_str, session=s)
        assert len(stale.get('found', [])) == 0
        assert len(stale.get('missing', [])) == 1

        await ds.delete(key, session=s)


# pylint: disable=too-many-locals
@pytest.mark.asyncio
async def test_run_query_with_read_time(
        creds: str, kind: str, project: str) -> None:
    """Verify that Datastore.runQuery honours the read_time option."""
    marker = f'read_time_test_{uuid.uuid4()}'

    async with Session() as s:
        ds = Datastore(project=project, service_file=creds, session=s)

        # Record a timestamp strictly before the entity exists, then insert.
        before_insert = datetime.datetime.now(datetime.timezone.utc)
        key = Key(project, [PathElement(kind, name=marker)])
        await ds.insert(key, {'test_field': marker}, session=s)

        query = Query(
            kind=kind,
            query_filter=Filter(PropertyFilter(
                prop='test_field',
                operator=PropertyFilterOperator.EQUAL,
                value=Value(marker)
            ))
        )

        # Without read_time, the freshly inserted entity is visible.
        current_result = await ds.runQuery(query, session=s)
        batch = current_result.result_batch
        assert len(batch.entity_results) == 1
        assert batch.entity_results[0].entity.properties[
            'test_field'] == marker

        # A query pinned to "now" still sees the entity.
        now = datetime.datetime.now(datetime.timezone.utc)
        now_str = now.isoformat().replace('+00:00', 'Z')
        pinned = await ds.runQuery(query, read_time=now_str, session=s)
        assert len(pinned.result_batch.entity_results) == 1

        # The batch echoes a non-empty string readTime.
        assert isinstance(pinned.result_batch.read_time, str)
        assert pinned.result_batch.read_time is not None

        # A query pinned before the insert must not see the entity.
        past = before_insert - datetime.timedelta(seconds=10)
        past_str = past.isoformat().replace('+00:00', 'Z')
        stale = await ds.runQuery(query, read_time=past_str, session=s)
        assert len(stale.result_batch.entity_results) == 0

        await ds.delete(key, session=s)
25 changes: 25 additions & 0 deletions datastore/tests/unit/datastore_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from gcloud.aio.datastore import Consistency
from gcloud.aio.datastore import Datastore
from gcloud.aio.datastore import Key
from gcloud.aio.datastore import Operation
Expand All @@ -16,6 +17,30 @@ def test_make_mutation_from_value_object(key):

assert results['insert']['properties']['value'] == value.to_repr()

# pylint: disable=protected-access
@staticmethod
def test_build_read_options_priority():
    """_build_read_options must prefer, in order: transaction,
    newTransaction, readTime, then the readConsistency fallback.
    """
    ds = Datastore()
    dt_str = '2025-01-01T12:00:00Z'

    class _FakeTransactionOptions:
        # Minimal stand-in exposing the only method the code calls.
        @staticmethod
        def to_repr():
            return {'readWrite': {}}

    new_txn = _FakeTransactionOptions()

    # transaction > newTransaction > readTime > consistency
    result = ds._build_read_options(
        Consistency.STRONG, new_txn, 'txn123', dt_str
    )
    assert result == {'transaction': 'txn123'}

    # newTransaction > readTime > consistency
    result = ds._build_read_options(
        Consistency.STRONG, new_txn, None, dt_str
    )
    assert result == {'newTransaction': {'readWrite': {}}}

    # readTime > consistency
    result = ds._build_read_options(
        Consistency.STRONG, None, None, dt_str
    )
    assert result == {'readTime': '2025-01-01T12:00:00Z'}

    # fall back to consistency
    result = ds._build_read_options(
        Consistency.STRONG, None, None, None
    )
    assert result == {'readConsistency': 'STRONG'}

@staticmethod
@pytest.fixture(scope='session')
def key() -> Key:
Expand Down
Loading