diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29b..31ecc738 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: patch
+  changes:
+    fixed:
+      - Added caching of saved simulations
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 9398ba49..90e642d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "microdf_python",
     "plotly>=5.0.0",
     "requests>=2.31.0",
+    "psutil>=5.9.0",
 ]
 
 [project.optional-dependencies]
diff --git a/src/policyengine/core/cache.py b/src/policyengine/core/cache.py
new file mode 100644
index 00000000..44de06e3
--- /dev/null
+++ b/src/policyengine/core/cache.py
@@ -0,0 +1,57 @@
+import logging
+from collections import OrderedDict
+
+import psutil
+
+logger = logging.getLogger(__name__)
+
+# Module-level so each threshold warns at most once per process, across caches.
+_MEMORY_THRESHOLDS_GB = [8, 16, 32]
+_warned_thresholds: set[int] = set()
+
+
+class LRUCache[T]:
+    """Least-recently-used cache with configurable size limit and memory monitoring."""
+
+    def __init__(self, max_size: int = 100):
+        self._max_size = max_size
+        self._cache: OrderedDict[str, T] = OrderedDict()
+
+    def get(self, key: str) -> T | None:
+        """Get item from cache, marking it as recently used."""
+        if key not in self._cache:
+            return None
+        self._cache.move_to_end(key)
+        return self._cache[key]
+
+    def add(self, key: str, value: T) -> None:
+        """Add or update an item, evicting the least-recently-used entry when full."""
+        if key in self._cache:
+            # Refresh recency; the stored value is overwritten below.
+            self._cache.move_to_end(key)
+        self._cache[key] = value
+        if len(self._cache) > self._max_size:
+            self._cache.popitem(last=False)
+
+        self._check_memory_usage()
+
+    def clear(self) -> None:
+        """Clear all items from cache and reset memory-threshold warnings."""
+        self._cache.clear()
+        _warned_thresholds.clear()
+
+    def __len__(self) -> int:
+        return len(self._cache)
+
+    def _check_memory_usage(self) -> None:
+        """Check memory usage and warn at threshold crossings."""
+        process = psutil.Process()
+        memory_gb = process.memory_info().rss / (1024**3)
+
+        for threshold in _MEMORY_THRESHOLDS_GB:
+            if memory_gb >= threshold and threshold not in _warned_thresholds:
+                logger.warning(
+                    f"Memory usage has reached {memory_gb:.2f}GB (threshold: {threshold}GB). "
+                    f"Cache contains {len(self._cache)} items."
+                )
+                _warned_thresholds.add(threshold)
diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py
index 0a5c106d..0c92382c 100644
--- a/src/policyengine/core/simulation.py
+++ b/src/policyengine/core/simulation.py
@@ -3,11 +3,14 @@
 
 from pydantic import BaseModel, Field
 
+from .cache import LRUCache
 from .dataset import Dataset
 from .dynamic import Dynamic
 from .policy import Policy
 from .tax_benefit_model_version import TaxBenefitModelVersion
 
+_cache: LRUCache["Simulation"] = LRUCache(max_size=100)
+
 
 class Simulation(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
@@ -25,12 +28,18 @@ def run(self):
         self.tax_benefit_model_version.run(self)
 
     def ensure(self):
+        # Serve a previously ensured simulation for this id when available.
+        cached_result = _cache.get(self.id)
+        if cached_result is not None:
+            return cached_result
         try:
             self.tax_benefit_model_version.load(self)
         except Exception:
             self.run()
             self.save()
 
+        _cache.add(self.id, self)
+
     def save(self):
         """Save the simulation's output dataset."""
         self.tax_benefit_model_version.save(self)
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 00000000..cb16ef65
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,156 @@
+import os
+import tempfile
+
+import pandas as pd
+from microdf import MicroDataFrame
+
+from policyengine.core import Simulation
+from policyengine.core.cache import LRUCache
+from policyengine.tax_benefit_models.uk import (
+    PolicyEngineUKDataset,
+    UKYearData,
+    uk_latest,
+)
+
+
+def test_simulation_cache_hit():
+    """Test that simulation caching works with UK simulations."""
+    person_df = MicroDataFrame(
+        pd.DataFrame(
+            {
+                "person_id": [1, 2, 3],
+                "benunit_id": [1, 1, 2],
+                "household_id": [1, 1, 2],
+                "age": [30, 25, 40],
+                "employment_income": [50000, 30000, 60000],
+                "person_weight": [1.0, 1.0, 1.0],
+            }
+        ),
+        weights="person_weight",
+    )
+
+    benunit_df = MicroDataFrame(
+        pd.DataFrame(
+            {
+                "benunit_id": [1, 2],
+                "benunit_weight": [1.0, 1.0],
+            }
+        ),
+        weights="benunit_weight",
+    )
+
+    household_df = MicroDataFrame(
+        pd.DataFrame(
+            {
+                "household_id": [1, 2],
+                "household_weight": [1.0, 1.0],
+            }
+        ),
+        weights="household_weight",
+    )
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        filepath = os.path.join(tmpdir, "test.h5")
+
+        dataset = PolicyEngineUKDataset(
+            name="Test",
+            description="Test dataset",
+            filepath=filepath,
+            year=2024,
+            data=UKYearData(
+                person=person_df, benunit=benunit_df, household=household_df
+            ),
+        )
+
+        simulation = Simulation(
+            dataset=dataset,
+            tax_benefit_model_version=uk_latest,
+            output_dataset=dataset,
+        )
+
+        # Import the cache
+        from policyengine.core.simulation import _cache
+
+        # Manually add to cache (simulating what ensure() does)
+        _cache.add(simulation.id, simulation)
+
+        # Verify simulation is in cache
+        assert simulation.id in _cache._cache
+        assert len(_cache) >= 1
+
+        # Verify cache returns same object
+        cached_sim = _cache.get(simulation.id)
+        assert cached_sim is simulation
+
+        # Clear cache for other tests
+        _cache.clear()
+
+
+def test_lru_cache_eviction():
+    """Test that LRU cache properly evicts old items."""
+    cache = LRUCache[str](max_size=3)
+
+    cache.add("a", "value_a")
+    cache.add("b", "value_b")
+    cache.add("c", "value_c")
+
+    assert len(cache) == 3
+    assert cache.get("a") == "value_a"
+
+    # Add fourth item, should evict 'b' (least recently used)
+    cache.add("d", "value_d")
+
+    assert len(cache) == 3
+    assert cache.get("b") is None
+    assert cache.get("a") == "value_a"
+    assert cache.get("c") == "value_c"
+    assert cache.get("d") == "value_d"
+
+
+def test_lru_cache_access_updates_order():
+    """Test that accessing items updates their position in LRU order."""
+    cache = LRUCache[str](max_size=3)
+
+    cache.add("a", "value_a")
+    cache.add("b", "value_b")
+    cache.add("c", "value_c")
+
+    # Access 'a' to move it to most recently used
+    cache.get("a")
+
+    # Add fourth item, should evict 'b' (now least recently used)
+    cache.add("d", "value_d")
+
+    assert cache.get("a") == "value_a"
+    assert cache.get("b") is None
+    assert cache.get("c") == "value_c"
+    assert cache.get("d") == "value_d"
+
+
+def test_lru_cache_clear():
+    """Test that clearing cache works properly."""
+    cache = LRUCache[str](max_size=10)
+
+    cache.add("a", "value_a")
+    cache.add("b", "value_b")
+    cache.add("c", "value_c")
+
+    assert len(cache) == 3
+
+    cache.clear()
+
+    assert len(cache) == 0
+    assert cache.get("a") is None
+    assert cache.get("b") is None
+    assert cache.get("c") is None
+
+
+def test_lru_cache_update_value():
+    """Test that adding an existing key updates the stored value."""
+    cache = LRUCache[str](max_size=3)
+
+    cache.add("a", "old")
+    cache.add("a", "new")
+
+    assert len(cache) == 1
+    assert cache.get("a") == "new"