-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Labels
Description
psyki-python/psyki/ski/kbann/__init__.py
Line 38 in 9dfb86b
# TODO: analyse this warning that sometimes comes out; it should not be harmful.
"""
# self.feature_mapping: dict[str, int] = feature_mapping
# Use as default fuzzifiers SubNetworkBuilder.
# TODO: analyse this warning that sometimes comes out, this should not be armful.
tf.get_logger().setLevel('ERROR')
self._predictor = predictor
self._fuzzifier = Fuzzifier.get(fuzzifier)([self._predictor.input, feature_mapping, omega])
self._fuzzy_functions: Iterable[Callable] = ()
self.gamma = gamma
class ConstrainedModel(EnrichedModel):
    """Enriched model whose training loss also penalises drift from its initial weights.

    The model's weights are snapshotted when the instance is built.
    ``loss_function`` wraps a user-supplied loss so that it also pays ``gamma``
    times a bounded measure of how far the current weights have moved away
    from that snapshot.
    """

    def __init__(self, model: Model, gamma: float, custom_objects: dict):
        """
        :param model: the model to wrap (forwarded to ``EnrichedModel``).
        :param gamma: weight of the drift penalty added to the original loss.
        :param custom_objects: custom objects forwarded to ``EnrichedModel``.
        """
        super().__init__(model, custom_objects)
        self.gamma = gamma
        # Snapshot the weights as immutable constant tensors rather than
        # deep-copied tf.Variables. Keras tracks Variables assigned to model
        # attributes as additional (trainable) weights, which would let the
        # optimiser move the anchor and would add the copies to self.weights.
        self.init_weights = [tf.constant(weight.numpy()) for weight in self.weights]

    class CustomLoss(Loss):
        """Original loss plus ``gamma`` times a bounded weight-drift penalty."""

        def __init__(self, original_loss: Callable, model: Model, init_weights, gamma: float):
            """
            :param original_loss: loss computed on ``(y_true, y_pred)``.
            :param model: model whose current weights are compared to ``init_weights``.
            :param init_weights: reference weights, in the same order as ``model.weights``.
            :param gamma: multiplier applied to the drift penalty.
            """
            self.original_loss = original_loss
            self.model = model
            self.init_weights = init_weights
            self.gamma = gamma
            super().__init__()

        def call(self, y_true, y_pred):
            return self.original_loss(y_true, y_pred) + self.gamma * self._cost_factor()

        def _cost_factor(self):
            """Return ``d / (1 + d)``, with ``d`` the summed squared weight drift.

            Reference and current weights are paired by position, so both
            sequences must list the weights in the same order. The result lies
            in ``[0, 1)``; it is ``0`` when there are no weights.
            """
            weights_quadratic_diff = 0
            for init_weight, current_weight in zip(self.init_weights, self.model.weights):
                weights_quadratic_diff += tf.math.reduce_sum((init_weight - current_weight) ** 2)
            # Squash into [0, 1) so the penalty stays bounded however far the weights drift.
            return weights_quadratic_diff / (1 + weights_quadratic_diff)

    def copy(self) -> EnrichedModel:
        """Return a deep copy of this model that keeps the same drift anchor."""
        with custom_object_scope(self.custom_objects):
            model = model_deep_copy(Model(self.input, self.output))
            copied = KBANN.ConstrainedModel(model, self.gamma, self.custom_objects)
        # The constructor snapshots the copy's *current* weights; restore this
        # model's original anchor so the copy is penalised identically.
        # Constant tensors are immutable, so sharing them is safe.
        copied.init_weights = list(self.init_weights)
        return copied

    def loss_function(self, original_function: Callable) -> Callable:
        """Wrap ``original_function`` with this model's weight-drift penalty."""
        return self.CustomLoss(original_function, self, self.init_weights, self.gamma)
def inject(self, rules: List[Formula]) -> Model:
# Prevent side effect on the original rules during optimization.