Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions servicex/dataset_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ async def as_signed_urls_async(
provided_progress: Optional[Progress] = None,
return_exceptions: bool = False,
overall_progress: bool = False,
) -> List[Union[TransformedResults, BaseException]]:
# preflight auth
) -> List[Union[TransformedResults, Exception]]:
if self.datasets:
await self.datasets[0].servicex._get_authorization()
await self.datasets[0].servicex.get_code_generators_async()
with ExpandableProgress(
display_progress, provided_progress, overall_progress=overall_progress
) as progress:
Expand All @@ -80,9 +80,13 @@ async def as_signed_urls_async(
)
for d in self.datasets
]
return await asyncio.gather(
*self.tasks, return_exceptions=return_exceptions
progress.disable_refresh()
gather_future = asyncio.ensure_future(
asyncio.gather(*self.tasks, return_exceptions=return_exceptions)
)
await asyncio.sleep(0)
progress.enable_refresh()
return await gather_future

as_signed_urls = make_sync(as_signed_urls_async)

Expand All @@ -92,10 +96,10 @@ async def as_files_async(
provided_progress: Optional[Progress] = None,
return_exceptions: bool = False,
overall_progress: bool = False,
) -> List[Union[TransformedResults, BaseException]]:
# preflight auth
) -> List[Union[TransformedResults, Exception]]:
if self.datasets:
await self.datasets[0].servicex._get_authorization()
await self.datasets[0].servicex.get_code_generators_async()
with ExpandableProgress(
display_progress, provided_progress, overall_progress=overall_progress
) as progress:
Expand All @@ -105,8 +109,12 @@ async def as_files_async(
)
for d in self.datasets
]
return await asyncio.gather(
*self.tasks, return_exceptions=return_exceptions
progress.disable_refresh()
gather_future = asyncio.ensure_future(
asyncio.gather(*self.tasks, return_exceptions=return_exceptions)
)
await asyncio.sleep(0)
progress.enable_refresh()
return await gather_future

as_files = make_sync(as_files_async)
14 changes: 14 additions & 0 deletions servicex/expandable_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,20 @@ def refresh(self):
if self.progress is not None:
self.progress.refresh()

def disable_refresh(self) -> None:
    """Suppress per-call refreshes to avoid repeated re-renders during bulk task
    registration. Call enable_refresh() to flush a single render afterward.

    Idempotent: a second call before enable_refresh() is a no-op, so the
    originally saved refresh callable is never overwritten by the no-op
    lambda (which would leave refresh permanently disabled after restore).
    """
    if self.progress is not None and not hasattr(self, "_saved_refresh"):
        # Save the real renderer once, then stub it out.
        self._saved_refresh = self.progress.refresh
        self.progress.refresh = lambda: None

def enable_refresh(self) -> None:
    """Restore the real refresh callable and emit one render so every task
    registered while refreshes were suppressed becomes visible at once."""
    if self.progress is None or not hasattr(self, "_saved_refresh"):
        return
    self.progress.refresh = self._saved_refresh
    del self._saved_refresh
    self.progress.refresh()


class TranformStatusProgress(Progress):
def get_renderables(self):
Expand Down
94 changes: 48 additions & 46 deletions servicex/query_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from datetime import datetime, timezone
from filelock import FileLock
from tinydb import TinyDB, Query, where
import logging

from servicex.configuration import Configuration
from servicex.models import TransformRequest, TransformStatus, TransformedResults
Expand All @@ -41,9 +42,13 @@ class CacheException(Exception):
pass


logger = logging.getLogger(__name__)


class QueryCache:
def __init__(self, config: Configuration):
self.config = config
self._mem_cache: Optional[List[dict]] = None
if self.config.cache_path is not None:
Path(self.config.cache_path).mkdir(parents=True, exist_ok=True)
Path(self.config.cache_path + "/.servicex").mkdir(
Expand All @@ -59,6 +64,17 @@ def __init__(self, config: Configuration):
def close(self):
    """Close the underlying TinyDB database handle."""
    self.db.close()

def _load_all(self) -> List[dict]:
"""Return all DB records, using an in-memory cache to avoid repeated JSON parses."""
if self._mem_cache is None:
with self.lock:
self._mem_cache = self.db.all()
return self._mem_cache

def _invalidate(self) -> None:
"""Invalidate the in-memory cache after a write."""
self._mem_cache = None

def transformed_results(
self,
transform: TransformRequest,
Expand Down Expand Up @@ -87,51 +103,41 @@ def cache_transform(self, record: TransformedResults):
self.db.upsert(
json.loads(record.model_dump_json()), transforms.hash == record.hash
)
self._invalidate()

def update_record(self, record: TransformedResults):
    """Overwrite the cached record whose hash matches *record*.

    The model is round-tripped through its JSON dump so TinyDB stores a
    plain dict; the in-memory snapshot is then invalidated so readers see
    the update.
    """
    transforms = Query()
    with self.lock:
        self.db.update(
            json.loads(record.model_dump_json()), transforms.hash == record.hash
        )
    self._invalidate()

def contains_hash(self, hash: str) -> bool:
    """
    Check if the cache has completed records for a hash.

    :param hash: transform-request hash to look up
    :return: True when at least one cached record matches the hash and is
        not still in the "SUBMITTED" state.
    """
    # Scan the in-memory snapshot; the superseded TinyDB Query()/db.search()
    # path (a diff-merge leftover) returned first and made this scan
    # unreachable dead code.
    return any(
        doc.get("hash") == hash and doc.get("status") != "SUBMITTED"
        for doc in self._load_all()
    )

def is_transform_request_submitted(self, hash_value: str) -> bool:
    """
    Returns True if request is submitted.
    Returns False if the request is not in the cache at all
    or not submitted.
    """
    # Merge leftover removed: a stale Query()/db.search() assignment was
    # immediately overwritten, and the old status check made the final
    # return unreachable.
    records = [doc for doc in self._load_all() if doc.get("hash") == hash_value]
    if not records:
        return False
    return records[0].get("status") == "SUBMITTED"

def get_transform_request_id(self, hash_value: str) -> Optional[str]:
    """
    Return the request id of cached record.

    :param hash_value: transform-request hash to look up
    :raises CacheException: when no record for the hash carries a request_id
    """
    # Merge leftover removed: the stale Query()/db.search() lines were dead
    # weight superseded by this in-memory scan.
    records = [doc for doc in self._load_all() if doc.get("hash") == hash_value]
    if not records or "request_id" not in records[0]:
        raise CacheException("Request Id not found")
    return records[0]["request_id"]
Expand All @@ -145,6 +151,7 @@ def update_transform_status(self, hash_value: str, status: str) -> None:
self.db.upsert(
{"hash": hash_value, "status": status}, transform.hash == hash_value
)
self._invalidate()

def update_transform_request_id(self, hash_value: str, request_id: str) -> None:
"""
Expand All @@ -156,6 +163,7 @@ def update_transform_request_id(self, hash_value: str, request_id: str) -> None:
{"hash": hash_value, "request_id": request_id},
transform.hash == hash_value,
)
self._invalidate()

def cache_submitted_transform(
self, transform: TransformRequest, request_id: str
Expand All @@ -174,16 +182,17 @@ def cache_submitted_transform(
transforms = Query()
with self.lock:
self.db.upsert(record, transforms.hash == record["hash"])
self._invalidate()

def get_transform_by_hash(self, hash: str) -> Optional[TransformedResults]:
"""
Returns completed transformations by hash
"""
transforms = Query()
with self.lock:
records = records = self.db.search(
(transforms.hash == hash) & ~(transforms.status == "SUBMITTED")
)
records = [
doc
for doc in self._load_all()
if doc.get("hash") == hash and doc.get("status") != "SUBMITTED"
]

if not records:
return None
Expand All @@ -199,10 +208,9 @@ def get_transform_by_request_id(
"""
Returns completed transformed results using a request id
"""
transforms = Query()

with self.lock:
records = self.db.search(transforms.request_id == request_id)
records = [
doc for doc in self._load_all() if doc.get("request_id") == request_id
]

if not records:
return None
Expand All @@ -220,33 +228,27 @@ def cache_path_for_transform(self, transform_status: TransformStatus) -> Path:
return result

def cached_queries(self) -> List[TransformedResults]:
    """Return TransformedResults for every completed record.

    Only records that carry a request_id and are past the "SUBMITTED"
    state are materialized.
    """
    # Merge leftover removed: the old locked TinyDB search returned first,
    # making this comprehension unreachable dead code.
    return [
        TransformedResults(**doc)
        for doc in self._load_all()
        if "request_id" in doc and doc.get("status") != "SUBMITTED"
    ]

def queries_in_state(self, state: str) -> List[dict]:
    """Return all transform records in a given state.

    :param state: status string to match (e.g. "SUBMITTED")
    :return: raw cache records whose status equals *state* and which
        already carry a request_id.
    """
    # The superseded TinyDB query hard-coded "SUBMITTED" and ignored the
    # *state* parameter, and its early return made this scan unreachable;
    # this version honors the parameter.
    return [
        doc
        for doc in self._load_all()
        if doc.get("status") == state and "request_id" in doc
    ]

def delete_record_by_request_id(self, request_id: str):
    """Remove every cached record carrying this request_id, then drop the
    in-memory snapshot so readers do not see the deleted record."""
    with self.lock:
        self.db.remove(where("request_id") == request_id)
    self._invalidate()

def delete_record_by_hash(self, hash: str):
    """Remove every cached record whose hash matches *hash*, then drop the
    in-memory snapshot so readers do not see the deleted record."""
    transforms = Query()
    with self.lock:
        self.db.remove(transforms.hash == hash)
    self._invalidate()
3 changes: 3 additions & 0 deletions tests/test_dataset_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def test_as_signed_urls(mocker, transformed_result):
ds1 = mocker.Mock()
ds1.as_signed_urls_async = AsyncMock(return_value=transformed_result)
ds1.servicex._get_authorization = AsyncMock()
ds1.servicex.get_code_generators_async = AsyncMock(return_value={})

ds2 = mocker.Mock()
ds2.as_signed_urls_async = AsyncMock(
Expand All @@ -67,6 +68,7 @@ async def test_as_files(mocker, transformed_result):
ds1 = mocker.Mock()
ds1.as_files_async = AsyncMock(return_value=transformed_result)
ds1.servicex._get_authorization = AsyncMock()
ds1.servicex.get_code_generators_async = AsyncMock(return_value={})

ds2 = mocker.Mock()
ds2.as_files_async = AsyncMock(
Expand All @@ -86,6 +88,7 @@ async def test_failure(mocker, transformed_result):
ds1 = mocker.Mock()
ds1.as_signed_urls_async = AsyncMock(return_value=transformed_result)
ds1.servicex._get_authorization = AsyncMock()
ds1.servicex.get_code_generators_async = AsyncMock(return_value={})

ds2 = mocker.Mock()
ds2.as_signed_urls_async = AsyncMock(side_effect=ServiceXException("dummy"))
Expand Down
Loading