diff --git a/error_parity/classifiers.py b/error_parity/classifiers.py index 3ffd73e..507191b 100644 --- a/error_parity/classifiers.py +++ b/error_parity/classifiers.py @@ -279,11 +279,17 @@ def find_weights_given_two_points( f"should be 1!" ) - if not all(np.isclose(target_point, all_weights @ all_points)): + if not all(np.isclose(target_point, all_weights @ all_points, atol=1e-5)): raise RuntimeError( f"Triangulation of target point failed. " f"Target was {target_point}; got {all_weights @ all_points}." ) + + if (all_weights < 0).any(): + all_weights = np.asarray(all_weights) + x_max = np.amax(all_weights, axis=0, keepdims=True) + exp_x_shifted = np.exp(all_weights - x_max) + all_weights = exp_x_shifted / np.sum(exp_x_shifted, axis=0, keepdims=True) return all_weights, all_points