From 4cf6220cf1d89491d3a428bb05075286f72e5aed Mon Sep 17 00:00:00 2001 From: davidwalter2 Date: Tue, 23 Dec 2025 09:30:46 -0500 Subject: [PATCH] Allow to freeze POI parameters; define tf.Variable to mask frozen parameters and allow to change after jitting --- rabbit/fitter.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 013787f..fddf5b0 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -125,8 +125,6 @@ def __init__(self, indata, poi_model, options, do_blinding=False): self.parms = np.concatenate([self.poi_model.pois, self.indata.systs]) - self.init_frozen_params(options.freezeParameters) - # tf variable containing all fit parameters thetadefault = tf.zeros([self.indata.nsyst], dtype=self.indata.dtype) if self.poi_model.npoi > 0: @@ -136,6 +134,15 @@ def __init__(self, indata, poi_model, options, do_blinding=False): self.x = tf.Variable(xdefault, trainable=True, name="x") + # for freezing parameters + self.frozen_params = [] + self.frozen_params_mask = tf.Variable( + tf.zeros_like(self.x, dtype=tf.bool), trainable=False, dtype=tf.bool + ) + + self.frozen_indices = np.array([]) + self.freeze_params(options.freezeParameters) + # observed number of events per bin self.nobs = tf.Variable( tf.zeros_like(self.indata.data_obs), trainable=False, name="nobs" @@ -294,12 +301,26 @@ def load_fitresult(self, fitresult_file, fitresult_key): self.x.assign(xvals) self.cov.assign(tf.constant(covval)) - def init_frozen_params(self, frozen_parmeter_expressions): - self.frozen_params = match_regexp_params( - frozen_parmeter_expressions, self.parms + def update_frozen_params(self): + new_mask_np = np.isin(self.parms, self.frozen_params) + + self.frozen_params_mask.assign(new_mask_np) + self.frozen_indices = np.where(new_mask_np)[0] + + def freeze_params(self, frozen_parmeter_expressions): + self.frozen_params.extend( + match_regexp_params(frozen_parmeter_expressions, self.parms) + ) + self.update_frozen_params() + + def defreeze_params(self, unfrozen_parmeter_expressions): + unfrozen_parmeter = match_regexp_params( + unfrozen_parmeter_expressions, self.parms ) - self.frozen_params_mask = np.isin(self.parms, self.frozen_params) - self.frozen_indices = np.where(self.frozen_params_mask)[0] + self.frozen_params = [ + x for x in self.frozen_params if x not in unfrozen_parmeter + ] + self.update_frozen_params() def init_blinding_values(self, unblind_parameter_expressions=[]): @@ -1265,6 +1286,9 @@ def _compute_yields_noBBB(self, full=True): poi = self.get_blinded_poi() theta = self.get_blinded_theta() + poi = tf.where( + self.frozen_params_mask[: self.poi_model.npoi], tf.stop_gradient(poi), poi + ) theta = tf.where( self.frozen_params_mask[self.poi_model.npoi :], tf.stop_gradient(theta),