Skip to content
Merged

Dev #29

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions aide_predict/bespoke_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init_subclass__(cls, **kwargs):
Initialize the subclass and register mixin hooks.

This method is called when a subclass is created. It collects and registers
all hooks from mixins in the inheritance chain.
all hooks from mixins in the inheritance chain, avoiding duplicates.
"""
super().__init_subclass__(**kwargs)

Expand All @@ -169,23 +169,47 @@ def __init_subclass__(cls, **kwargs):
cls._pre_transform_hooks = []
cls._post_transform_hooks = []

# Keep track of function objects we've already added to avoid duplicates
seen_init_handlers = set()
seen_pre_fit_hooks = set()
seen_post_fit_hooks = set()
seen_pre_transform_hooks = set()
seen_post_transform_hooks = set()

# Collect hooks from all mixin bases
for base in cls.__mro__[1:]: # Skip self, look at all parents
# Init handlers
if hasattr(base, '_init_handler') and callable(base._init_handler):
cls._mixin_init_handlers.append(base._init_handler)
handler = base._init_handler
if handler not in seen_init_handlers:
cls._mixin_init_handlers.append(handler)
seen_init_handlers.add(handler)

# Fit hooks
if hasattr(base, '_pre_fit_hook') and callable(base._pre_fit_hook):
cls._pre_fit_hooks.append(base._pre_fit_hook)
hook = base._pre_fit_hook
if hook not in seen_pre_fit_hooks:
cls._pre_fit_hooks.append(hook)
seen_pre_fit_hooks.add(hook)

if hasattr(base, '_post_fit_hook') and callable(base._post_fit_hook):
cls._post_fit_hooks.append(base._post_fit_hook)
hook = base._post_fit_hook
if hook not in seen_post_fit_hooks:
cls._post_fit_hooks.append(hook)
seen_post_fit_hooks.add(hook)

# Transform hooks
if hasattr(base, '_pre_transform_hook') and callable(base._pre_transform_hook):
cls._pre_transform_hooks.append(base._pre_transform_hook)
hook = base._pre_transform_hook
if hook not in seen_pre_transform_hooks:
cls._pre_transform_hooks.append(hook)
seen_pre_transform_hooks.add(hook)

if hasattr(base, '_post_transform_hook') and callable(base._post_transform_hook):
cls._post_transform_hooks.append(base._post_transform_hook)
hook = base._post_transform_hook
if hook not in seen_post_transform_hooks:
cls._post_transform_hooks.append(hook)
seen_post_transform_hooks.add(hook)

# Stop at ProteinModelWrapper (don't include its parents)
if base is ProteinModelWrapper:
Expand Down
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- bioconda
dependencies:
- python<=3.10
- python<=3.11
- hmmer=3.3.2
- mmseqs2
- plmc
Expand Down
Loading