diff --git a/aide_predict/bespoke_models/base.py b/aide_predict/bespoke_models/base.py index 6511b99..4e5afb7 100644 --- a/aide_predict/bespoke_models/base.py +++ b/aide_predict/bespoke_models/base.py @@ -158,7 +158,7 @@ def __init_subclass__(cls, **kwargs): Initialize the subclass and register mixin hooks. This method is called when a subclass is created. It collects and registers - all hooks from mixins in the inheritance chain. + all hooks from mixins in the inheritance chain, avoiding duplicates. """ super().__init_subclass__(**kwargs) @@ -169,23 +169,47 @@ def __init_subclass__(cls, **kwargs): cls._pre_transform_hooks = [] cls._post_transform_hooks = [] + # Keep track of function objects we've already added to avoid duplicates + seen_init_handlers = set() + seen_pre_fit_hooks = set() + seen_post_fit_hooks = set() + seen_pre_transform_hooks = set() + seen_post_transform_hooks = set() + # Collect hooks from all mixin bases for base in cls.__mro__[1:]: # Skip self, look at all parents # Init handlers if hasattr(base, '_init_handler') and callable(base._init_handler): - cls._mixin_init_handlers.append(base._init_handler) + handler = base._init_handler + if handler not in seen_init_handlers: + cls._mixin_init_handlers.append(handler) + seen_init_handlers.add(handler) # Fit hooks if hasattr(base, '_pre_fit_hook') and callable(base._pre_fit_hook): - cls._pre_fit_hooks.append(base._pre_fit_hook) + hook = base._pre_fit_hook + if hook not in seen_pre_fit_hooks: + cls._pre_fit_hooks.append(hook) + seen_pre_fit_hooks.add(hook) + if hasattr(base, '_post_fit_hook') and callable(base._post_fit_hook): - cls._post_fit_hooks.append(base._post_fit_hook) + hook = base._post_fit_hook + if hook not in seen_post_fit_hooks: + cls._post_fit_hooks.append(hook) + seen_post_fit_hooks.add(hook) # Transform hooks if hasattr(base, '_pre_transform_hook') and callable(base._pre_transform_hook): - cls._pre_transform_hooks.append(base._pre_transform_hook) + hook = base._pre_transform_hook + if hook not in seen_pre_transform_hooks: + cls._pre_transform_hooks.append(hook) + seen_pre_transform_hooks.add(hook) + if hasattr(base, '_post_transform_hook') and callable(base._post_transform_hook): - cls._post_transform_hooks.append(base._post_transform_hook) + hook = base._post_transform_hook + if hook not in seen_post_transform_hooks: + cls._post_transform_hooks.append(hook) + seen_post_transform_hooks.add(hook) # Stop at ProteinModelWrapper (don't include its parents) if base is ProteinModelWrapper: diff --git a/environment.yaml b/environment.yaml index 9afd7b5..3a36746 100644 --- a/environment.yaml +++ b/environment.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - bioconda dependencies: - - python<=3.10 + - python<=3.11 - hmmer=3.3.2 - mmseqs2 - plmc