-
Notifications
You must be signed in to change notification settings - Fork 3
Description
I was trying to run `causality_unified_exp.ipynb`, and the following error is raised in the function `get_AUCs` (the last training-log lines and the full traceback are shown below):
iter 29920 loss_total: 0.069716 loss_train: 0.005160 loss_test: 0.000007 noise_norm: 0.021518 reg: 0.000000655 lr = 0.000000
iter 29940 loss_total: 0.071126 loss_train: 0.006571 loss_test: 0.000007 noise_norm: 0.021518 reg: 0.000000655 lr = 0.000000
iter 29960 loss_total: 0.072121 loss_train: 0.007565 loss_test: 0.000007 noise_norm: 0.021518 reg: 0.000000655 lr = 0.000000
iter 29980 loss_total: 0.069464 loss_train: 0.004908 loss_test: 0.000007 noise_norm: 0.021518 reg: 0.000000655 lr = 0.000000
ValueError Traceback (most recent call last)
<ipython-input-...> in <module>
128 plot_interval = 5 if info_estimate_mode == "var" else 200
129 )
--> 130 item_dict["metrics"] = get_AUCs(item_dict["result"][0], A_whole, neglect_idx = neglect_idx)
131 pickle.dump(learned_dict, open(filename, "wb"))
132
~/Documents/causal_climate/causal/causality/util_causality.py in get_AUCs(causality_value, A_whole, neglect_idx, verbose)
81 if len(causality_truth.shape) == 3:
82 causality_truth = causality_truth.any(-2)
---> 83 ROC_AUC, ROC_AUC_mean, ROC_AUC_list = get_ROC_AUC(causality_truth, causality_value, neglect_idx = neglect_idx)
84 PR_AUC, PR_AUC_mean, PR_AUC_list = get_PR_AUC(causality_truth, causality_value, neglect_idx = neglect_idx)
85 if verbose:
~/Documents/causal_climate/causal/causality/util_causality.py in get_ROC_AUC(causality_truth, causality_value, neglect_idx)
38 causality_value = to_np_array(causality_value)
39 assert not np.isinf(causality_value).any() and not np.isnan(causality_value).any()
---> 40 ROC_AUC = roc_auc_score(flatten_matrix(causality_truth, neglect_idx), flatten_matrix(causality_value, neglect_idx))
41 ROC_AUC_list = []
42 for i in range(len(causality_truth)):
~/.local/lib/python3.6/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py in roc_auc_score(y_true, y_score, average, sample_weight, max_fpr, multi_class, labels)
543 max_fpr=max_fpr),
544 y_true, y_score, average,
--> 545 sample_weight=sample_weight)
546 else: # multilabel-indicator
547 return _average_binary_score(partial(_binary_roc_auc_score,
~/.local/lib/python3.6/site-packages/sklearn/metrics/_base.py in _average_binary_score(binary_metric, y_true, y_score, average, sample_weight)
75
76 if y_type == "binary":
---> 77 return binary_metric(y_true, y_score, sample_weight=sample_weight)
78
79 check_consistent_length(y_true, y_score, sample_weight)
~/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py in _binary_roc_auc_score(y_true, y_score, sample_weight, max_fpr)
329
330 fpr, tpr, _ = roc_curve(y_true, y_score,
--> 331 sample_weight=sample_weight)
332 if max_fpr is None or max_fpr == 1:
333 return auc(fpr, tpr)
~/.local/lib/python3.6/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py in roc_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate)
912 """
913 fps, tps, thresholds = _binary_clf_curve(
--> 914 y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)
915
916 # Attempt to drop thresholds corresponding to points in between and
~/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
691 raise ValueError("{0} format is not supported".format(y_type))
692
--> 693 check_consistent_length(y_true, y_score, sample_weight)
694 y_true = column_or_1d(y_true)
695 y_score = column_or_1d(y_score)
~/.local/lib/python3.6/site-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
318 if len(uniques) > 1:
319 raise ValueError("Found input variables with inconsistent numbers of"
--> 320 " samples: %r" % [int(l) for l in lengths])
321
322
ValueError: Found input variables with inconsistent numbers of samples: [90, 210]