diff --git a/servicex/dataset_group.py b/servicex/dataset_group.py index 5b3f5cda..c6b2d9c9 100644 --- a/servicex/dataset_group.py +++ b/servicex/dataset_group.py @@ -65,10 +65,10 @@ async def as_signed_urls_async( provided_progress: Optional[Progress] = None, return_exceptions: bool = False, overall_progress: bool = False, - ) -> List[Union[TransformedResults, BaseException]]: - # preflight auth + ) -> List[Union[TransformedResults, Exception]]: if self.datasets: await self.datasets[0].servicex._get_authorization() + await self.datasets[0].servicex.get_code_generators_async() with ExpandableProgress( display_progress, provided_progress, overall_progress=overall_progress ) as progress: @@ -80,9 +80,13 @@ async def as_signed_urls_async( ) for d in self.datasets ] - return await asyncio.gather( - *self.tasks, return_exceptions=return_exceptions + progress.disable_refresh() + gather_future = asyncio.ensure_future( + asyncio.gather(*self.tasks, return_exceptions=return_exceptions) ) + await asyncio.sleep(0) + progress.enable_refresh() + return await gather_future as_signed_urls = make_sync(as_signed_urls_async) @@ -92,10 +96,10 @@ async def as_files_async( provided_progress: Optional[Progress] = None, return_exceptions: bool = False, overall_progress: bool = False, - ) -> List[Union[TransformedResults, BaseException]]: - # preflight auth + ) -> List[Union[TransformedResults, Exception]]: if self.datasets: await self.datasets[0].servicex._get_authorization() + await self.datasets[0].servicex.get_code_generators_async() with ExpandableProgress( display_progress, provided_progress, overall_progress=overall_progress ) as progress: @@ -105,8 +109,12 @@ async def as_files_async( ) for d in self.datasets ] - return await asyncio.gather( - *self.tasks, return_exceptions=return_exceptions + progress.disable_refresh() + gather_future = asyncio.ensure_future( + asyncio.gather(*self.tasks, return_exceptions=return_exceptions) ) + await asyncio.sleep(0) + progress.enable_refresh() + return await gather_future as_files = make_sync(as_files_async) diff --git a/servicex/expandable_progress.py b/servicex/expandable_progress.py index bed32839..6d7426ab 100644 --- a/servicex/expandable_progress.py +++ b/servicex/expandable_progress.py @@ -229,6 +229,20 @@ def refresh(self): if self.progress is not None: self.progress.refresh() + def disable_refresh(self) -> None: + """Suppress per-call refreshes to avoid repeated re-renders during bulk task + registration. Call enable_refresh() to flush a single render afterward.""" + if self.progress is not None: + self._saved_refresh = self.progress.refresh + self.progress.refresh = lambda: None + + def enable_refresh(self) -> None: + """Re-enable rendering and emit a single refresh to show all registered tasks.""" + if self.progress is not None and hasattr(self, "_saved_refresh"): + self.progress.refresh = self._saved_refresh + del self._saved_refresh + self.progress.refresh() + class TranformStatusProgress(Progress): def get_renderables(self): diff --git a/servicex/query_cache.py b/servicex/query_cache.py index c8066ee1..5778d0b2 100644 --- a/servicex/query_cache.py +++ b/servicex/query_cache.py @@ -32,6 +32,7 @@ from datetime import datetime, timezone from filelock import FileLock from tinydb import TinyDB, Query, where +import logging from servicex.configuration import Configuration from servicex.models import TransformRequest, TransformStatus, TransformedResults @@ -41,9 +42,13 @@ class CacheException(Exception): pass +logger = logging.getLogger(__name__) + + class QueryCache: def __init__(self, config: Configuration): self.config = config + self._mem_cache: Optional[List[dict]] = None if self.config.cache_path is not None: Path(self.config.cache_path).mkdir(parents=True, exist_ok=True) Path(self.config.cache_path + "/.servicex").mkdir( @@ -59,6 +64,17 @@ def __init__(self, config: Configuration): def close(self): self.db.close() + def _load_all(self) -> List[dict]: + """Return all DB records, using an in-memory cache to avoid repeated JSON parses.""" + if self._mem_cache is None: + with self.lock: + self._mem_cache = self.db.all() + return self._mem_cache + + def _invalidate(self) -> None: + """Invalidate the in-memory cache after a write.""" + self._mem_cache = None + def transformed_results( self, transform: TransformRequest, @@ -87,6 +103,7 @@ def cache_transform(self, record: TransformedResults): self.db.upsert( json.loads(record.model_dump_json()), transforms.hash == record.hash ) + self._invalidate() def update_record(self, record: TransformedResults): transforms = Query() @@ -94,17 +111,16 @@ def update_record(self, record: TransformedResults): self.db.update( json.loads(record.model_dump_json()), transforms.hash == record.hash ) + self._invalidate() def contains_hash(self, hash: str) -> bool: """ Check if the cache has completed records for a hash """ - transforms = Query() - with self.lock: - records = self.db.search( - (transforms.hash == hash) & ~(transforms.status == "SUBMITTED") - ) - return len(records) > 0 + return any( + doc.get("hash") == hash and doc.get("status") != "SUBMITTED" + for doc in self._load_all() + ) def is_transform_request_submitted(self, hash_value: str) -> bool: """ @@ -112,26 +128,16 @@ def is_transform_request_submitted(self, hash_value: str) -> bool: Returns False if the request is not in the cache at all or not submitted """ - transform = Query() - with self.lock: - records = self.db.search((transform.hash == hash_value)) - + records = [doc for doc in self._load_all() if doc.get("hash") == hash_value] if not records: return False - - if "status" in records[0] and records[0]["status"] == "SUBMITTED": - return True - return False + return records[0].get("status") == "SUBMITTED" def get_transform_request_id(self, hash_value: str) -> Optional[str]: """ Return the request id of cached record """ - transform = Query() - - with self.lock: - records = self.db.search(transform.hash == hash_value) - + records = [doc for doc in self._load_all() if doc.get("hash") == hash_value] if not records or "request_id" not in records[0]: raise CacheException("Request Id not found") return records[0]["request_id"] @@ -145,6 +151,7 @@ def update_transform_status(self, hash_value: str, status: str) -> None: self.db.upsert( {"hash": hash_value, "status": status}, transform.hash == hash_value ) + self._invalidate() def update_transform_request_id(self, hash_value: str, request_id: str) -> None: """ @@ -156,6 +163,7 @@ def update_transform_request_id(self, hash_value: str, request_id: str) -> None: {"hash": hash_value, "request_id": request_id}, transform.hash == hash_value, ) + self._invalidate() def cache_submitted_transform( self, transform: TransformRequest, request_id: str @@ -174,16 +182,17 @@ def cache_submitted_transform( transforms = Query() with self.lock: self.db.upsert(record, transforms.hash == record["hash"]) + self._invalidate() def get_transform_by_hash(self, hash: str) -> Optional[TransformedResults]: """ Returns completed transformations by hash """ - transforms = Query() - with self.lock: - records = records = self.db.search( - (transforms.hash == hash) & ~(transforms.status == "SUBMITTED") - ) + records = [ + doc + for doc in self._load_all() + if doc.get("hash") == hash and doc.get("status") != "SUBMITTED" + ] if not records: return None @@ -199,10 +208,9 @@ def get_transform_by_request_id( """ Returns completed transformed results using a request id """ - transforms = Query() - - with self.lock: - records = self.db.search(transforms.request_id == request_id) + records = [ + doc for doc in self._load_all() if doc.get("request_id") == request_id + ] if not records: return None @@ -220,33 +228,27 @@ def cache_path_for_transform(self, transform_status: TransformStatus) -> Path: return result def cached_queries(self) -> List[TransformedResults]: - transforms = Query() - - with self.lock: - result = [ - TransformedResults(**doc) - for doc in self.db.search( - transforms.request_id.exists() & ~(transforms.status == "SUBMITTED") - ) - ] - return result + return [ + TransformedResults(**doc) + for doc in self._load_all() + if "request_id" in doc and doc.get("status") != "SUBMITTED" + ] def queries_in_state(self, state: str) -> List[dict]: """Return all transform records in a given state.""" - transforms = Query() - with self.lock: - return [ - doc - for doc in self.db.search( - (transforms.status == "SUBMITTED") & transforms.request_id.exists() - ) - ] + return [ + doc + for doc in self._load_all() + if doc.get("status") == state and "request_id" in doc + ] def delete_record_by_request_id(self, request_id: str): with self.lock: self.db.remove(where("request_id") == request_id) + self._invalidate() def delete_record_by_hash(self, hash: str): transforms = Query() with self.lock: self.db.remove(transforms.hash == hash) + self._invalidate() diff --git a/tests/test_dataset_group.py b/tests/test_dataset_group.py index 53c54646..6fa41d17 100644 --- a/tests/test_dataset_group.py +++ b/tests/test_dataset_group.py @@ -48,6 +48,7 @@ async def test_as_signed_urls(mocker, transformed_result): ds1 = mocker.Mock() ds1.as_signed_urls_async = AsyncMock(return_value=transformed_result) ds1.servicex._get_authorization = AsyncMock() + ds1.servicex.get_code_generators_async = AsyncMock(return_value={}) ds2 = mocker.Mock() ds2.as_signed_urls_async = AsyncMock( @@ -67,6 +68,7 @@ async def test_as_files(mocker, transformed_result): ds1 = mocker.Mock() ds1.as_files_async = AsyncMock(return_value=transformed_result) ds1.servicex._get_authorization = AsyncMock() + ds1.servicex.get_code_generators_async = AsyncMock(return_value={}) ds2 = mocker.Mock() ds2.as_files_async = AsyncMock( @@ -86,6 +88,7 @@ async def test_failure(mocker, transformed_result): ds1 = mocker.Mock() ds1.as_signed_urls_async = AsyncMock(return_value=transformed_result) ds1.servicex._get_authorization = AsyncMock() + ds1.servicex.get_code_generators_async = AsyncMock(return_value={}) ds2 = mocker.Mock() ds2.as_signed_urls_async = AsyncMock(side_effect=ServiceXException("dummy")) diff --git a/tests/test_query_cache.py b/tests/test_query_cache.py index 010928db..9a26bdd5 100644 --- a/tests/test_query_cache.py +++ b/tests/test_query_cache.py @@ -117,7 +117,8 @@ def test_cache_transform(transform_request, completed_status): assert len(cache.cached_queries()) == 1 - # forcefully create a duplicate record + # forcefully create a duplicate record directly in the DB (bypasses _invalidate), + # so explicitly clear the in-memory cache to reflect the new state record = json.loads( cache.transformed_results( transform=transform_request, @@ -130,6 +131,7 @@ def test_cache_transform(transform_request, completed_status): record["hash"] = transform_request.compute_hash() record["status"] = "COMPLETE" cache.db.insert(record) + cache._invalidate() with pytest.raises(CacheException): cache.get_transform_by_hash(transform_request.compute_hash()) @@ -285,6 +287,60 @@ def test_get_transform_request_status(transform_request, completed_status): cache.close() +def test_mem_cache_reads_db_once(transform_request, completed_status, mocker): + """Multiple consecutive reads should only parse the DB file once.""" + with tempfile.TemporaryDirectory() as temp_dir: + config = Configuration(cache_path=temp_dir, api_endpoints=[]) # type: ignore + cache = QueryCache(config) + cache.cache_transform( + cache.transformed_results( + transform=transform_request, + completed_status=completed_status, + data_dir="/foo/bar", + file_list=file_uris, + signed_urls=[], + ) + ) + # Spy on the underlying db.all() call + spy = mocker.spy(cache.db, "all") + + cache.get_transform_by_hash(transform_request.compute_hash()) + cache.get_transform_by_hash(transform_request.compute_hash()) + cache.contains_hash(transform_request.compute_hash()) + cache.cached_queries() + + # Four reads but db.all() should only have been called once + assert spy.call_count == 1 + cache.close() + + +def test_mem_cache_invalidated_on_write(transform_request, completed_status, mocker): + """Writing a record invalidates the in-memory cache so the next read reloads.""" + with tempfile.TemporaryDirectory() as temp_dir: + config = Configuration(cache_path=temp_dir, api_endpoints=[]) # type: ignore + cache = QueryCache(config) + cache.cache_transform( + cache.transformed_results( + transform=transform_request, + completed_status=completed_status, + data_dir="/foo/bar", + file_list=file_uris, + signed_urls=[], + ) + ) + spy = mocker.spy(cache.db, "all") + + # First read populates the cache + cache.get_transform_by_hash(transform_request.compute_hash()) + assert spy.call_count == 1 + + # A write should invalidate, so the next read hits the DB again + cache.cache_submitted_transform(transform_request, "new-request-id") + cache.get_transform_by_hash(transform_request.compute_hash()) + assert spy.call_count == 2 + cache.close() + + def test_cache_queries_in_state(transform_request): with tempfile.TemporaryDirectory() as temp_dir: config = Configuration(cache_path=temp_dir, api_endpoints=[]) # type: ignore @@ -302,3 +358,46 @@ def test_cache_queries_in_state(transform_request): ) cache.close() + + +def test_update_record(transform_request, completed_status): + with tempfile.TemporaryDirectory() as temp_dir: + config = Configuration(cache_path=temp_dir, api_endpoints=[]) # type: ignore + cache = QueryCache(config) + original = cache.transformed_results( + transform=transform_request, + completed_status=completed_status, + data_dir="/foo/bar", + file_list=file_uris, + signed_urls=[], + ) + cache.cache_transform(original) + + # Populate the in-memory cache + assert cache.get_transform_by_hash(transform_request.compute_hash()) + + updated = original.model_copy(update={"data_dir": "/foo/updated"}) + cache.update_record(updated) + + # Cache should be invalidated; next read reflects the update + result = cache.get_transform_by_hash(transform_request.compute_hash()) + assert result is not None + assert result.data_dir == "/foo/updated" + cache.close() + + +def test_update_transform_request_id(transform_request): + with tempfile.TemporaryDirectory() as temp_dir: + config = Configuration(cache_path=temp_dir, api_endpoints=[]) # type: ignore + cache = QueryCache(config) + hash_value = transform_request.compute_hash() + cache.cache_submitted_transform(transform_request, "old-request-id") + + # Populate the in-memory cache + assert cache.get_transform_request_id(hash_value) == "old-request-id" + + cache.update_transform_request_id(hash_value, "new-request-id") + + # Cache should be invalidated; next read reflects the new request id + assert cache.get_transform_request_id(hash_value) == "new-request-id" + cache.close()