Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 2 additions & 31 deletions bin/rabbit_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import argparse
import time

import h5py
import numpy as np
import scipy

from rabbit import fitter, inputdata, io_tools, workspace
from rabbit import fitter, inputdata, workspace
from rabbit.mappings import helpers as mh
from rabbit.mappings import mapping as mp
from rabbit.poi_models import helpers as ph
Expand Down Expand Up @@ -473,35 +472,7 @@ def fit(args, fitter, ws, dofit=True):
edmval = None

if args.externalPostfit is not None:
# load results from external fit and set postfit value and covariance elements for common parameters
with h5py.File(args.externalPostfit, "r") as fext:
if "x" in fext.keys():
# fitresult from combinetf
x_ext = fext["x"][...]
parms_ext = fext["parms"][...].astype(str)
cov_ext = fext["cov"][...]
else:
# fitresult from rabbit
h5results_ext = io_tools.get_fitresult(fext, args.externalPostfitResult)
h_parms_ext = h5results_ext["parms"].get()

x_ext = h_parms_ext.values()
parms_ext = np.array(h_parms_ext.axes["parms"])
cov_ext = h5results_ext["cov"].get().values()

xvals = fitter.x.numpy()
covval = fitter.cov.numpy()
parms = fitter.parms.astype(str)

# Find common elements with their matching indices
common_elements, idxs, idxs_ext = np.intersect1d(
parms, parms_ext, assume_unique=True, return_indices=True
)
xvals[idxs] = x_ext[idxs_ext]
covval[np.ix_(idxs, idxs)] = cov_ext[np.ix_(idxs_ext, idxs_ext)]

fitter.x.assign(xvals)
fitter.cov.assign(tf.constant(covval))
fitter.load_fitresult(args.externalPostfit, args.externalPostfitResult)
else:
if dofit:
fitter.minimize()
Expand Down
2 changes: 1 addition & 1 deletion bin/rabbit_plot_hists.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ def make_plots(

axes = [a for a in hist_inclusive.axes]

if args.processGrouping is not None:
if args.processGrouping is not None and len(hist_stack):
hist_stack, labels, colors, procs = config.process_grouping(
args.processGrouping, hist_stack, procs
)
Expand Down
7 changes: 4 additions & 3 deletions bin/rabbit_plot_hists_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def plot_matrix(
ticklabels=None,
):

opts = dict()

if not isinstance(matrix, np.ndarray):
matrix = matrix.values()

Expand All @@ -163,10 +165,8 @@ def plot_matrix(
if args.correlation:
std_dev = np.sqrt(np.diag(matrix))
matrix = matrix / np.outer(std_dev, std_dev)
opts.update(dict(vmin=-1, vmax=1))

fig, ax = plt.subplots(figsize=(8, 6))

opts = dict()
if ticklabels is not None:
opts.update(
dict(
Expand All @@ -175,6 +175,7 @@ def plot_matrix(
)
)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
matrix,
cmap=cmap,
Expand Down
64 changes: 39 additions & 25 deletions bin/rabbit_plot_likelihood_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def parseArgs():
default=[],
help="Parameters to plot the likelihood scan",
)
parser.add_argument(
"--noHessian",
action="store_true",
help="Don't include Hessian likelihood approximation in plot",
)
parser.add_argument(
"--combine",
type=str,
Expand Down Expand Up @@ -128,9 +133,10 @@ def plot_scan(
subtitle=None,
titlePos=0,
xlim=None,
combine=None,
ylabel=r"$-2\,\Delta \log L$",
config={},
combine=None,
no_hessian=False,
):

xlabel = getattr(config, "systematics_labels", {}).get(param, param)
Expand All @@ -140,42 +146,54 @@ def plot_scan(
x = np.array(h_scan.axes["scan"]).astype(float)[mask]
y = h_scan.values()[mask] * 2

if xlim is None:
xlim = (min(x), max(x))

fig, ax = plot_tools.figure(
x,
xlabel,
ylabel,
xlim=(min(x), max(x)),
xlim=xlim,
ylim=(min(y), max(y)), # logy=args.logy
)

ax.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
ax.axhline(y=4, color="gray", linestyle="--", alpha=0.5)

parabola_vals = param_value + np.linspace(
-3 * param_variance**0.5, 3 * param_variance**0.5, 100
)
parabola_nlls = 1 / param_variance * (parabola_vals - param_value) ** 2
ax.plot(
parabola_vals,
parabola_nlls,
marker="",
markerfacecolor="none",
color="red",
linestyle="-",
label="Hessian",
)
if not no_hessian:
parabola_vals = param_value + np.linspace(
-3 * param_variance**0.5, 3 * param_variance**0.5, 100
)
parabola_nlls = 1 / param_variance * (parabola_vals - param_value) ** 2
ax.plot(
parabola_vals,
parabola_nlls,
marker="",
markerfacecolor="none",
color="red",
linestyle="-",
label="Hessian",
)

ax.plot(
x,
y,
marker="x",
color="blue",
label="Likelihood scan",
marker=None, # "x",
color="black",
label="Likelihood scan" if combine is None else "Rabbit",
markeredgewidth=2,
linewidth=2,
)

if combine is not None:
ax.plot(*combine, marker="o", color="orange", label="Combine")
ax.plot(
*combine,
marker="o",
linestyle=None,
linewidth=0,
color="orange",
label="Combine",
)

if h_contours is not None:
for i, cl in enumerate(h_contours.axes["confidence_level"]):
Expand All @@ -201,12 +219,7 @@ def plot_scan(
ax.legend(loc="upper right")

plot_tools.add_decor(
ax,
title,
subtitle,
data=False,
lumi=None,
loc=titlePos,
ax, title, subtitle, data=True, lumi=None, loc=titlePos, no_energy=True
)

return fig
Expand Down Expand Up @@ -269,6 +282,7 @@ def main():
xlim=args.xlim,
config=config,
combine=(vals, nlls) if args.combine is not None else None,
no_hessian=args.noHessian,
)
os.makedirs(args.outpath, exist_ok=True)
outfile = os.path.join(args.outpath, f"nll_scan_{param}")
Expand Down
65 changes: 51 additions & 14 deletions bin/rabbit_plot_pulls_and_impacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,13 @@ def parseArgs():
nargs="+",
help="Print impacts on observables use '-m <mapping> channel axes' for mapping results.",
)
parser.add_argument(
"--mappingRef",
default=None,
type=str,
nargs="+",
help="Print impacts on observables use '-m <mapping> channel axes' for mapping results for reference.",
)
parser.add_argument(
"-r",
"--referenceFile",
Expand Down Expand Up @@ -1227,7 +1234,9 @@ def main():
translate_label = getattr(config, "impact_labels", {})

fitresult, meta = io_tools.get_fitresult(args.inputFile, args.result, meta=True)
if args.referenceFile is not None or args.refResult is not None:
if any(
x is not None for x in [args.referenceFile, args.refResult, args.mappingRef]
):
referenceFile = (
args.referenceFile if args.referenceFile is not None else args.inputFile
)
Expand Down Expand Up @@ -1276,14 +1285,36 @@ def main():
"Only global impacts on observables is implemented (use --globalImpacts)"
)

def get_mapping_key(result, key):

res = result.get("mappings", fitresult.get("physics_models"))
if key in res.keys():
channels = res[key]["channels"]
return channels, key
else:
keys = [key for key in res.keys() if key.startswith(key)]
if len(keys) == 0:
raise ValueError(
f"Mapping {key} not found, available mappings are: {res.keys()}"
)

channels = res[keys[0]]["channels"]
return channels, keys[0]

mapping_key = " ".join(args.mapping)
results = fitresult.get("mappings", fitresult.get("physics_models"))

if mapping_key in results.keys():
channels = results[mapping_key]["channels"]
else:
keys = [key for key in results.keys() if key.startswith(mapping_key)]
channels = results[keys[0]]["channels"]
channels, mapping_key = get_mapping_key(fitresult, mapping_key)

if fitresult_ref:
mapping_key_ref = (
" ".join(args.mappingRef)
if args.mappingRef is not None
else mapping_key
)

channels_ref, mapping_key_ref = get_mapping_key(
fitresult_ref, mapping_key_ref
)

for channel, hists in channels.items():

Expand All @@ -1299,15 +1330,21 @@ def main():

hist = hists[key].get()

# TODO: implement ref
# hist_ref
# hist_total_ref
if fitresult_ref:
if channel in channels_ref.keys():
channel_ref = channel
elif len(channels_ref.keys()) == 1:
channel_ref = [v for v in channels_ref.keys()][0]
else:
raise NotImplementedError(
f"Could not decide which is the right channel from reference file with channels: {channels_ref.keys()}"
)

if fitresult_ref is not None:
results_ref = fitresult_ref.get(
"mappings", fitresult_ref.get("physics_models")
res_ref = fitresult_ref.get(
"mappings", fitresult.get("physics_models")
)
hists_ref = results_ref[mapping_key]["channels"][channel]
hists_ref = res_ref[mapping_key_ref]["channels"][channel_ref]

hist_ref = hists_ref[key].get()
hist_total_ref = hists_ref["hist_postfit_inclusive"].get()

Expand Down
33 changes: 33 additions & 0 deletions rabbit/fitter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import hashlib
import re

import h5py
import numpy as np
import scipy
import tensorflow as tf
import tensorflow_probability as tfp
from wums import logging

from rabbit import io_tools
from rabbit import tfhelpers as tfh

logger = logging.child_logger(__name__)
Expand Down Expand Up @@ -261,6 +263,37 @@ def __init__(self, indata, poi_model, options, do_blinding=False):
and ((not self.binByBinStat) or self.binByBinStatType == "normal-additive")
)

def load_fitresult(self, fitresult_file, fitresult_key):
# load results from external fit and set postfit value and covariance elements for common parameters
with h5py.File(fitresult_file, "r") as fext:
if "x" in fext.keys():
# fitresult from combinetf
x_ext = fext["x"][...]
parms_ext = fext["parms"][...].astype(str)
cov_ext = fext["cov"][...]
else:
# fitresult from rabbit
h5results_ext = io_tools.get_fitresult(fext, fitresult_key)
h_parms_ext = h5results_ext["parms"].get()

x_ext = h_parms_ext.values()
parms_ext = np.array(h_parms_ext.axes["parms"])
cov_ext = h5results_ext["cov"].get().values()

xvals = self.x.numpy()
covval = self.cov.numpy()
parms = self.parms.astype(str)

# Find common elements with their matching indices
common_elements, idxs, idxs_ext = np.intersect1d(
parms, parms_ext, assume_unique=True, return_indices=True
)
xvals[idxs] = x_ext[idxs_ext]
covval[np.ix_(idxs, idxs)] = cov_ext[np.ix_(idxs_ext, idxs_ext)]

self.x.assign(xvals)
self.cov.assign(tf.constant(covval))

def init_frozen_params(self, frozen_parmeter_expressions):
self.frozen_params = match_regexp_params(
frozen_parmeter_expressions, self.parms
Expand Down
1 change: 1 addition & 0 deletions rabbit/mappings/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"Project": "project",
"Normalize": "project",
"Ratio": "ratio",
"Difference": "ratio",
"Normratio": "ratio",
"Asymmetry": "ratio",
"AngularCoefficients": "angular_coefficients",
Expand Down
16 changes: 16 additions & 0 deletions rabbit/mappings/ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,19 @@ def compute_flat(self, params, observables):
exp = (num - den) / (num + den)
exp = tf.reshape(exp, [-1])
return exp


class Difference(Ratio):
"""
Same as Ratio but compute the difference of numerator and denominator
"""

def init(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def compute_flat(self, params, observables):
num = self.num.select(observables, normalize=True, inclusive=True)
den = self.den.select(observables, normalize=True, inclusive=True)
exp = num - den
exp = tf.reshape(exp, [-1])
return exp
Loading