From ab5c843f3aaed4be8b10a25e82166e8a1d62e71c Mon Sep 17 00:00:00 2001 From: NglQ <93099056+NglQ@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:21:47 +0200 Subject: [PATCH] Update classifiers.py Robust probability computation of weights --- error_parity/classifiers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/error_parity/classifiers.py b/error_parity/classifiers.py index 3ffd73e..507191b 100644 --- a/error_parity/classifiers.py +++ b/error_parity/classifiers.py @@ -279,11 +279,17 @@ def find_weights_given_two_points( f"should be 1!" ) - if not all(np.isclose(target_point, all_weights @ all_points)): + if not all(np.isclose(target_point, all_weights @ all_points, atol=1e-5)): raise RuntimeError( f"Triangulation of target point failed. " f"Target was {target_point}; got {all_weights @ all_points}." ) + + if (all_weights < 0).any(): + all_weights = np.asarray(all_weights) + x_max = np.amax(all_weights, axis=0, keepdims=True) + exp_x_shifted = np.exp(all_weights - x_max) + all_weights = exp_x_shifted / np.sum(exp_x_shifted, axis=0, keepdims=True) return all_weights, all_points