Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- Added caching of saved simulations
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"microdf_python",
"plotly>=5.0.0",
"requests>=2.31.0",
"psutil>=5.9.0",
]

[project.optional-dependencies]
Expand Down
56 changes: 56 additions & 0 deletions src/policyengine/core/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from collections import OrderedDict

import psutil

logger = logging.getLogger(__name__)

_MEMORY_THRESHOLDS_GB = [8, 16, 32]
_warned_thresholds: set[int] = set()


class LRUCache[T]:
"""Least-recently-used cache with configurable size limit and memory monitoring."""

def __init__(self, max_size: int = 100):
self._max_size = max_size
self._cache: OrderedDict[str, T] = OrderedDict()

def get(self, key: str) -> T | None:
"""Get item from cache, marking it as recently used."""
if key not in self._cache:
return None
self._cache.move_to_end(key)
return self._cache[key]

def add(self, key: str, value: T) -> None:
"""Add item to cache with LRU eviction when full."""
if key in self._cache:
self._cache.move_to_end(key)
else:
self._cache[key] = value
if len(self._cache) > self._max_size:
self._cache.popitem(last=False)

self._check_memory_usage()

def clear(self) -> None:
"""Clear all items from cache."""
self._cache.clear()
_warned_thresholds.clear()

def __len__(self) -> int:
return len(self._cache)

def _check_memory_usage(self) -> None:
"""Check memory usage and warn at threshold crossings."""
process = psutil.Process()
memory_gb = process.memory_info().rss / (1024**3)

for threshold in _MEMORY_THRESHOLDS_GB:
if memory_gb >= threshold and threshold not in _warned_thresholds:
logger.warning(
f"Memory usage has reached {memory_gb:.2f}GB (threshold: {threshold}GB). "
f"Cache contains {len(self._cache)} items."
)
_warned_thresholds.add(threshold)
8 changes: 8 additions & 0 deletions src/policyengine/core/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

from pydantic import BaseModel, Field

from .cache import LRUCache
from .dataset import Dataset
from .dynamic import Dynamic
from .policy import Policy
from .tax_benefit_model_version import TaxBenefitModelVersion

_cache: LRUCache["Simulation"] = LRUCache(max_size=100)


class Simulation(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
Expand All @@ -25,12 +28,17 @@ def run(self):
self.tax_benefit_model_version.run(self)

def ensure(self):
cached_result = _cache.get(self.id)
if cached_result:
return cached_result
try:
self.tax_benefit_model_version.load(self)
except Exception:
self.run()
self.save()

_cache.add(self.id, self)

    def save(self):
        """Save the simulation's output dataset.

        Delegates persistence to this simulation's tax-benefit model
        version, which owns the storage format and location.
        """
        self.tax_benefit_model_version.save(self)
Expand Down
145 changes: 145 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import os
import tempfile

import pandas as pd
from microdf import MicroDataFrame

from policyengine.core import Simulation
from policyengine.core.cache import LRUCache
from policyengine.tax_benefit_models.uk import (
PolicyEngineUKDataset,
UKYearData,
uk_latest,
)


def test_simulation_cache_hit():
    """A simulation added to the module-level cache is returned by identity."""
    from policyengine.core.simulation import _cache

    people = pd.DataFrame(
        {
            "person_id": [1, 2, 3],
            "benunit_id": [1, 1, 2],
            "household_id": [1, 1, 2],
            "age": [30, 25, 40],
            "employment_income": [50000, 30000, 60000],
            "person_weight": [1.0, 1.0, 1.0],
        }
    )
    benunits = pd.DataFrame(
        {
            "benunit_id": [1, 2],
            "benunit_weight": [1.0, 1.0],
        }
    )
    households = pd.DataFrame(
        {
            "household_id": [1, 2],
            "household_weight": [1.0, 1.0],
        }
    )

    year_data = UKYearData(
        person=MicroDataFrame(people, weights="person_weight"),
        benunit=MicroDataFrame(benunits, weights="benunit_weight"),
        household=MicroDataFrame(households, weights="household_weight"),
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        dataset = PolicyEngineUKDataset(
            name="Test",
            description="Test dataset",
            filepath=os.path.join(tmpdir, "test.h5"),
            year=2024,
            data=year_data,
        )

        simulation = Simulation(
            dataset=dataset,
            tax_benefit_model_version=uk_latest,
            output_dataset=dataset,
        )

        # Mimic what ensure() does on a cache miss.
        _cache.add(simulation.id, simulation)

        # The entry is present and the exact same object comes back.
        assert simulation.id in _cache._cache
        assert len(_cache) >= 1
        assert _cache.get(simulation.id) is simulation

        # Leave a clean cache for subsequent tests.
        _cache.clear()


def test_lru_cache_eviction():
    """The least-recently-used entry is dropped once capacity is exceeded."""
    cache = LRUCache[str](max_size=3)

    for letter in ("a", "b", "c"):
        cache.add(letter, f"value_{letter}")

    assert len(cache) == 3
    # Touch "a" so that "b" becomes the least recently used entry.
    assert cache.get("a") == "value_a"

    # Exceed capacity: "b" should be evicted.
    cache.add("d", "value_d")

    assert len(cache) == 3
    assert cache.get("b") is None
    for letter in ("a", "c", "d"):
        assert cache.get(letter) == f"value_{letter}"


def test_lru_cache_access_updates_order():
    """Reading an entry refreshes its recency and changes eviction order."""
    cache = LRUCache[str](max_size=3)

    for letter in ("a", "b", "c"):
        cache.add(letter, f"value_{letter}")

    # Reading "a" makes "b" the oldest entry.
    cache.get("a")

    # The next insertion beyond capacity evicts "b", not "a".
    cache.add("d", "value_d")

    assert cache.get("b") is None
    for letter in ("a", "c", "d"):
        assert cache.get(letter) == f"value_{letter}"


def test_lru_cache_clear():
    """clear() removes every entry and leaves the cache empty."""
    cache = LRUCache[str](max_size=10)

    keys = ("a", "b", "c")
    for key in keys:
        cache.add(key, f"value_{key}")

    assert len(cache) == len(keys)

    cache.clear()

    assert len(cache) == 0
    for key in keys:
        assert cache.get(key) is None