Merged
2 changes: 2 additions & 0 deletions .github/workflows/pre-commit.yml

@@ -13,6 +13,8 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
+        with:
+          python-version: '3.9'
       - uses: pre-commit/action@v3.0.1
 
   verify-output:
78 changes: 42 additions & 36 deletions causing/model.py

@@ -51,65 +51,56 @@ def __post_init__(self):
     def compute(
         self,
         xdat: np.array,
         # fix a yval
         fixed_yval: np.array = None,
         fixed_yind: int = None,
         # fix an arbitrary node going into a yval
         fixed_from_ind: int = None,
         fixed_to_yind: int = None,
         fixed_vals: list = None,
         # override default parameter values
         parameters: dict[str, float] = {},
     ) -> np.array:
-        """Compute y values for given x values
+        """Compute y values for given x values (Optimized Version)
 
         xdat: m rows, tau columns
         returns: n rows, tau columns
         """
 
         assert xdat.ndim == 2, f"xdat must be m*tau (is {xdat.ndim}-dimensional)"
         assert xdat.shape[0] == self.mdim, f"xdat must be m*tau (is {xdat.shape})"
         tau = xdat.shape[1]
         parameters = self.parameters | parameters
 
-        yhat = np.array([[float("nan")] * tau] * len(self.yvars))
+        # Use np.full for clarity and potential small performance gain
+        yhat = np.full((self.ndim, tau), np.nan)
 
         for i, eq in enumerate(self._model_lam):
-
             if fixed_yind == i:
                 yhat[i, :] = fixed_yval
             else:
-                eq_inputs = np.array(
-                    [[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
-                )
+                # 1. Build inputs (Vectorized)
+                # Use np.vstack for efficient vertical stacking.
+                # `yhat` will have NaNs for unsolved variables, which is correct.
+                eq_inputs = np.vstack([xdat, yhat])
 
+                # 2. Apply fixed values if needed (Vectorized)
                 if fixed_to_yind == i:
-                    eq_inputs[:, fixed_from_ind] = fixed_vals
+                    # Directly modify the correct "row" in the input matrix.
+                    # This is much cleaner and faster.
+                    eq_inputs[fixed_from_ind, :] = fixed_vals
 
+                # 3. Evaluate equation (Vectorized)
                 try:
-                    # print(f"Comuting variable: {self.yvars[i]}")
-                    # yhat[i] = np.array(
-                    #     [eq(*eq_in, *parameters.values()) for eq_in in eq_inputs],
-                    #     dtype=np.float64,
-                    # )
-                    np.seterr(under="ignore")
-                    computed_yvars = []
-                    for eq_in in eq_inputs:
-                        computed_yvars.append(eq(*eq_in, *parameters.values()))
-
-                    yhat[i] = np.array(
-                        computed_yvars,
-                        dtype=np.float64,
-                    )
-
+                    # This is the core optimization. We unpack the rows of `eq_inputs`
+                    # as separate arguments into the lambdified function. NumPy will
+                    # then compute the results for all `tau` columns at once.
+                    yhat[i, :] = eq(*eq_inputs, *parameters.values())
                 except Exception as e:
-                    # for eq_in in eq_inputs:
-                    #     print("--", self.yvars[i])
-                    #     for var, val in zip(
-                    #         self.vars + list(parameters.keys()),
-                    #         list(eq_in) + list(parameters.values()),
-                    #     ):
-                    #         print(var, "=", val)
-                    #     eq(*eq_in, *parameters.values())
                     raise NumericModelError(
                         f"Failed to compute model value for yvar {self.yvars[i]}: {e}"
                     ) from e
+        assert yhat.shape == (self.ndim, tau)
 
         return yhat
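Review note (not part of the diff): a minimal, self-contained sketch of the core vectorization, using a hypothetical two-equation model; the symbols, equations, and data below are illustrative only, not from this repository. Stacking xdat on top of yhat and unpacking the rows into a numpy-lambdified function hands each symbol a full length-tau row, so a single call evaluates all observations at once.

import numpy as np
import sympy

# Hypothetical model (illustrative only): y1 = x1 + x2, y2 = y1 * x1
x1, x2, y1, y2 = sympy.symbols("x1 x2 y1 y2")
lambdified = [
    sympy.lambdify([x1, x2, y1, y2], eq, modules="numpy")
    for eq in (x1 + x2, y1 * x1)
]

xdat = np.array([[1.0, 2.0, 3.0],   # x1 over tau = 3 observations
                 [4.0, 5.0, 6.0]])  # x2
tau = xdat.shape[1]
yhat = np.full((2, tau), np.nan)    # NaN marks still-unsolved y rows

for i, eq in enumerate(lambdified):
    # Stack x rows on top of y rows: one (m + n) x tau input matrix.
    # Later equations can read earlier yhat rows, which are already filled.
    eq_inputs = np.vstack([xdat, yhat])
    # Unpacking hands each symbol a whole length-tau row, so NumPy
    # broadcasts the equation across all tau columns in one call.
    yhat[i, :] = eq(*eq_inputs)

print(yhat)
# [[ 5.  7.  9.]
#  [ 5. 14. 27.]]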
@@ -198,10 +189,25 @@ def calc_effects(self, xdat: np.array, xdat_mean=None, yhat_mean=None):
 
     @cached_property
     def _model_lam(self) -> Iterable[Callable]:
-        return [
-            sympy.lambdify(self.vars + list(self.parameters), eq)
-            for eq in self.equations
-        ]
+        """Create lambdified equations with NumPy-compatible functions."""
+        lambdas = []
+        ordered_vars = self.vars + list(self.parameters.keys())
+
+        # Define placeholder for vectorized max function
+        vectorized_max = sympy.Function("vectorized_max")
+
+        # Define custom translation mapping
+        custom_modules = [{"vectorized_max": np.maximum}, "numpy"]
+
+        for i, eq in enumerate(self.equations):
+            # Replace sympy.Max with our placeholder
+            fixed_eq = eq.subs(sympy.Max, vectorized_max)
+
+            # Lambdify with custom NumPy mapping
+            lam = sympy.lambdify(ordered_vars, fixed_eq, modules=custom_modules)
+            lambdas.append(lam)
+
+        return lambdas
 
     @cached_property
     def final_ind(self):
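Review note (not part of the diff): depending on the SymPy version, Max can lambdify to a reducing amax or to Python's scalar max, either of which misbehaves when each argument is a whole length-tau row. The substitution in _model_lam routes Max through a placeholder function that the modules mapping binds to elementwise np.maximum. A minimal sketch of the same trick on a toy expression (the symbols and expression are illustrative):

import numpy as np
import sympy

x, y = sympy.symbols("x y")
expr = sympy.Max(x, 0) + y  # Max embedded in a model equation

# Placeholder that the modules mapping will bind to np.maximum,
# which broadcasts elementwise over length-tau input rows.
vectorized_max = sympy.Function("vectorized_max")
fixed_expr = expr.subs(sympy.Max, vectorized_max)

# Entries in the dict take priority over the trailing "numpy" module.
f = sympy.lambdify(
    [x, y], fixed_expr, modules=[{"vectorized_max": np.maximum}, "numpy"]
)

print(f(np.array([-1.0, 2.0]), np.array([10.0, 10.0])))  # [10. 12.]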