Skip to content
6 changes: 4 additions & 2 deletions pybmc/bmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def predict(self, property):
- lower_df: DataFrame with columns domain_keys + ['Predicted_Lower']
- median_df: DataFrame with columns domain_keys + ['Predicted_Median']
- upper_df: DataFrame with columns domain_keys + ['Predicted_Upper']
- weights: numpy.ndarray with posterior model weight samples
"""
if self.samples is None or self.Vt_hat is None:
raise ValueError("Must call `orthogonalize()` and `train()` before predicting.")
Expand All @@ -142,7 +143,8 @@ def predict(self, property):
model_preds = df[available_models].values
domain_df = df[domain_keys].reset_index(drop=True)

rndm_m, (lower, median, upper) = rndm_m_random_calculator(model_preds, self.samples, self.Vt_hat)
rndm_m, (lower, median, upper), weights = rndm_m_random_calculator(
model_preds, self.samples, self.Vt_hat, output_weights=True)

# Build output DataFrames
lower_df = domain_df.copy()
Expand All @@ -155,7 +157,7 @@ def predict(self, property):
upper_df = domain_df.copy()
upper_df["Predicted_Upper"] = upper

return rndm_m, lower_df, median_df, upper_df
return rndm_m, lower_df, median_df, upper_df, weights

def evaluate(self, domain_filter=None):
"""
Expand Down
10 changes: 7 additions & 3 deletions pybmc/sampling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def coverage(percentiles, rndm_m, models_output, truth_column):
return coverage_results


def rndm_m_random_calculator(filtered_model_predictions, samples, Vt_hat):
def rndm_m_random_calculator(filtered_model_predictions, samples, Vt_hat, output_weights=False):
"""
Generates posterior predictive samples and credible intervals.

Expand All @@ -50,6 +50,7 @@ def rndm_m_random_calculator(filtered_model_predictions, samples, Vt_hat):
tuple[numpy.ndarray, list[numpy.ndarray]]:
- `rndm_m` (numpy.ndarray): Posterior predictive samples.
- `[lower, median, upper]` (list[numpy.ndarray]): Credible interval arrays.
- `model_weights_random` (numpy.ndarray): Posterior model weight samples
"""
np.random.seed(142858)
rng = np.random.default_rng()
Expand Down Expand Up @@ -80,5 +81,8 @@ def rndm_m_random_calculator(filtered_model_predictions, samples, Vt_hat):
lower_radius = np.percentile(rndm_m, 2.5, axis=0)
median_radius = np.percentile(rndm_m, 50, axis=0)
upper_radius = np.percentile(rndm_m, 97.5, axis=0)

return rndm_m, [lower_radius, median_radius, upper_radius]
if output_weights:
return rndm_m, [lower_radius, median_radius, upper_radius], model_weights_random
else:
return rndm_m, [lower_radius, median_radius, upper_radius]

5 changes: 3 additions & 2 deletions tests/test_bmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_predict(self):
self.bmc.train()
# Use all rows for prediction input
X = self.df[["x", "y", "model1", "model2", "model3"]].copy()
rndm_m, lower_df, median_df, upper_df = self.bmc.predict(self.property)
rndm_m, lower_df, median_df, upper_df, weights = self.bmc.predict(self.property)

self.assertEqual(rndm_m.shape[1], len(X))
self.assertIn("Predicted_Lower", lower_df.columns)
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_bmc_predict(self):
bmc.train()

# Perform prediction using property name
rndm_m, lower_df, median_df, upper_df = bmc.predict("property")
rndm_m, lower_df, median_df, upper_df, weights = bmc.predict("property")

self.assertIsNotNone(rndm_m)
self.assertIsInstance(lower_df, pd.DataFrame)
Expand All @@ -169,6 +169,7 @@ def test_bmc_predict(self):
self.assertFalse(lower_df.empty)
self.assertFalse(median_df.empty)
self.assertFalse(upper_df.empty)
self.assertIsNotNone(weights)
# Check domain columns are present
for col in ["N", "Z"]:
self.assertIn(col, lower_df.columns)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def mock_hdf_reader(file, key):
self.assertIsNotNone(bmc.Vt_hat)

# Step 5: Predict on all domain points (including those without truth)
rndm_m, lower_df, median_df, upper_df = bmc.predict("BE")
rndm_m, lower_df, median_df, upper_df, weights= bmc.predict("BE")

# Verify predictions
self.assertEqual(rndm_m.shape[1], 6, "Should have predictions for all 6 domain points")
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_bmc_workflow_with_smaller_truth_domain_csv(self, mock_read_csv, mock_ex
bmc.train()

# Predict on all points
rndm_m, lower_df, median_df, upper_df = bmc.predict("BE")
rndm_m, lower_df, median_df, upper_df, weights = bmc.predict("BE")

# Verify predictions cover all domain points
self.assertEqual(len(lower_df), 6)
Expand Down