Please visit the Documentation for further information, or refer to the Publication.
- Project Description
- Getting Started
- Key Features
- Installation Instructions
- Usage
- Advanced Usage
- Extending Functionality
- Contributing Guidelines
- License Information
- Publication
## Project Description

PyTorch Bio Transformations is a Python library that implements biologically inspired modifications to artificial neural networks, based on research into dendritic spine dynamics. It aims to explore and enhance the learning capabilities of neural networks by mimicking the plasticity and stability characteristics observed in biological synapses.

The project is primarily targeted at researchers and developers in machine learning and computational neuroscience who are interested in bio-inspired approaches to augmenting neural network performance.
## Getting Started

```bash
# We recommend creating a new environment
conda create -n biomod python=3.11
conda activate biomod

# Install the package
pip install numpy torch torchvision torchaudio
pip install pytorch_bio_transformations
```

```python
# Convert your PyTorch model in just 3 lines
from bio_transformations import BioConverter

converter = BioConverter()
bio_model = converter(your_pytorch_model)

# Use bio_model as you would a regular PyTorch model.
# During training, apply the bio-inspired mechanisms:
optimizer.zero_grad()
loss.backward()
bio_model.fuzzy_learning_rates()  # Apply diverse learning rates
bio_model.crystallize()           # Stabilize well-optimized weights
optimizer.step()
```

## Key Features

Bio Transformations implements several biologically inspired methods, each mimicking specific aspects of neuronal behavior:
- **Synaptic Diversity** (`fuzzy_learning_rates`): Implements diverse learning rates for different "synapses" (weights), mimicking the variability observed in biological synapses.
- **Structural Plasticity** (`rejuvenate_weights`): Simulates spine turnover by randomly reinitializing certain weights, allowing the "formation" of new connections and the "pruning" of others.
- **Synaptic Stabilization** (`crystallize`): Mimics the stabilization of frequently used synapses by reducing learning rates for well-optimized weights.
- **Multi-synaptic Connectivity** (`weight_splitting`): Allows multiple "synapses" (sub-weights) to exist for each connection, enhancing the reliability and flexibility of neural circuits.
- **Volume-dependent Plasticity** (`volume_dependent_lr`): Adjusts learning rates based on weight magnitude (analogous to spine volume); larger weights receive smaller, less variable learning rates.
- **Homeostatic Plasticity** (`scale_grad`): Implements synaptic scaling to maintain overall network stability while still allowing learning.
- **Dale's Principle** (`enforce_dales_principle`): Ensures that all outgoing weights from a given artificial "neuron" share the same sign, mimicking the constraints imposed by neurotransmitter types.
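To build intuition for the first of these mechanisms, here is a minimal plain-PyTorch sketch of synaptic diversity: each weight receives its own fixed multiplicative factor drawn around 1.0, and its gradient is scaled by that factor. This is an illustrative approximation, not the library's implementation; the spread `nu` mirrors the role of the `fuzzy_learning_rate_factor_nu` parameter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sketch, not the library's implementation:
# per-weight multiplicative learning-rate factors around 1.0.
nu = 0.16  # spread of the factors (cf. fuzzy_learning_rate_factor_nu)
layer = nn.Linear(10, 20)
factors = (1.0 + nu * torch.randn_like(layer.weight)).clamp(min=0.0)

x = torch.randn(4, 10)
layer(x).sum().backward()
with torch.no_grad():
    layer.weight.grad.mul_(factors)  # each "synapse" now learns at its own rate
```

Because the factors are multiplicative around 1.0, the expected update is unchanged while individual weights drift apart in effective learning rate.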
These methods work in concert to create a learning process that more closely resembles the dynamics observed in biological neural networks, potentially leading to improved learning and generalization in artificial neural networks.
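For a flavor of how such a constraint can be expressed, here is a conceptual plain-PyTorch sketch of a Dale's-principle projection (illustrative only, not the library's `enforce_dales_principle`). In an `nn.Linear` weight matrix, column j holds the outgoing weights of input neuron j, so every entry in that column is forced to agree with neuron j's fixed sign.

```python
import torch
import torch.nn as nn

# Illustrative sketch, not the library's implementation: zero out
# weight entries whose sign disagrees with their presynaptic
# neuron's fixed (excitatory/inhibitory) sign.
def dale_projection(layer: nn.Linear, signs: torch.Tensor) -> None:
    """signs: +1/-1 per input neuron, shape (in_features,)."""
    with torch.no_grad():
        w = layer.weight  # (out_features, in_features); column j = outputs of neuron j
        w.copy_(torch.where(w * signs >= 0, w, torch.zeros_like(w)))

layer = nn.Linear(5, 3)
signs = torch.tensor([1.0, -1.0, 1.0, 1.0, -1.0])  # excitatory / inhibitory neurons
dale_projection(layer, signs)
```

After the projection, every surviving weight in a column shares that neuron's sign (zeroed entries are sign-neutral).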
## Installation Instructions

You can install Bio Transformations using pip or from source.

```bash
# We recommend creating a new environment
conda create -n biomod python=3.11
conda activate biomod
```

First install PyTorch (see PyTorch.org for the command matching your system).

GPU (CUDA 12.4):

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
```

CPU only:

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

Then install Bio Transformations from PyPI:

```bash
pip install pytorch_bio_transformations
```

Or install from source:

```bash
git clone https://github.com/CeadeS/pytorch_bio_transformations
cd pytorch_bio_transformations
pip install -r requirements.txt
pip install -e .
```

## Usage
```python
import torch
import torch.nn as nn
from bio_transformations import BioConverter, BioConfig

# Define your model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create and convert your model
model = SimpleModel()
converter = BioConverter(
    fuzzy_learning_rate_factor_nu=0.16,  # Controls the diversity in learning rates
    dampening_factor=0.6,                # Controls the stability increase during crystallization
    crystal_thresh=4.5e-05               # Threshold for identifying weights to crystallize
)
bio_model = converter(model)

# Use bio_model as you would a regular PyTorch model
x = torch.randn(1, 10)
output = bio_model(x)
print(output)
```
A fuller example trains a small CNN on MNIST and applies the bio-inspired updates during training:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from bio_transformations import BioConverter

# Define a simple CNN for MNIST
class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

def main():
    # Training settings
    batch_size = 64
    epochs = 3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare the MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)

    # Create a bio-inspired model
    model = MNISTNet().to(device)
    converter = BioConverter(
        fuzzy_learning_rate_factor_nu=0.16,  # Controls learning rate diversity
        dampening_factor=0.7,                # For synaptic stabilization
        crystal_thresh=4.5e-05,              # Threshold for crystallization
        rejuvenation_parameter_dre=10.0      # Controls weight rejuvenation
    )
    bio_model = converter(model)

    # Define the optimizer
    optimizer = optim.SGD(bio_model.parameters(), lr=0.01)

    # Training loop
    for epoch in range(1, epochs + 1):
        # Training phase
        bio_model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Forward pass
            optimizer.zero_grad()
            output = bio_model(data)
            loss = F.nll_loss(output, target)

            # Backward pass
            loss.backward()

            # Apply bio-inspired modifications
            bio_model.fuzzy_learning_rates()  # Apply diverse learning rates
            if batch_idx % 100 == 0:
                bio_model.crystallize()  # Stabilize important weights periodically

            # Update weights
            optimizer.step()

            # Print progress
            if batch_idx % 100 == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]'
                      f' Loss: {loss.item():.4f}')

        # Apply weight rejuvenation at the end of each epoch
        bio_model.rejuvenate_weights()

        # Testing phase
        bio_model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = bio_model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

if __name__ == "__main__":
    main()
```

## Advanced Usage

Bio Transformations offers extensive configuration options through the `BioConfig` class:
```python
import torch.nn as nn
from bio_transformations import BioConverter, BioConfig
from bio_transformations.bio_config import Distribution

# Create a detailed configuration
config = BioConfig(
    # Fuzzy learning rate parameters
    fuzzy_learning_rate_factor_nu=0.16,         # Controls the variability in learning rates
    fuzzy_lr_distribution=Distribution.NORMAL,  # Distribution strategy for learning rates
    fuzzy_lr_dynamic=True,                      # Whether to update learning rates during training

    # Synaptic stabilization parameters
    dampening_factor=0.6,    # Factor for reducing learning rates during crystallization
    crystal_thresh=4.5e-05,  # Threshold for identifying weights to crystallize

    # Structural plasticity parameters
    rejuvenation_parameter_dre=8.0,  # Controls the rate of weight rejuvenation

    # Multi-synaptic connectivity parameters
    weight_splitting_Gamma=2,                        # Number of sub-synapses per connection
    weight_splitting_activation_function=nn.ReLU(),  # Activation function for weight splitting

    # Volume-dependent plasticity parameters
    base_lr=0.1,           # Base learning rate for volume-dependent plasticity
    stability_factor=2.0,  # Controls how quickly stability increases with weight size
    lr_variability=0.2     # Controls the amount of variability in learning rates
)

converter = BioConverter(config=config)
```

Bio Transformations supports various distribution strategies for fuzzy learning rates:
```python
from bio_transformations import BioConfig
from bio_transformations.bio_config import Distribution

# Different distribution strategies
basic_config = BioConfig(fuzzy_lr_distribution=Distribution.BASELINE)       # All factors = 1.0 (no variability)
uniform_config = BioConfig(fuzzy_lr_distribution=Distribution.UNIFORM)      # Uniform distribution around 1.0
normal_config = BioConfig(fuzzy_lr_distribution=Distribution.NORMAL)        # Normal distribution centered at 1.0
lognormal_config = BioConfig(fuzzy_lr_distribution=Distribution.LOGNORMAL)  # Log-normal with mean 1.0
gamma_config = BioConfig(fuzzy_lr_distribution=Distribution.GAMMA)          # Gamma distribution (positive, skewed)
beta_config = BioConfig(fuzzy_lr_distribution=Distribution.BETA)            # Scaled beta distribution
layer_config = BioConfig(fuzzy_lr_distribution=Distribution.LAYER_ADAPTIVE)    # Layer-dependent variability
weight_config = BioConfig(fuzzy_lr_distribution=Distribution.WEIGHT_ADAPTIVE)  # Weight-dependent scaling
temporal_config = BioConfig(fuzzy_lr_distribution=Distribution.TEMPORAL, fuzzy_lr_dynamic=True)  # Evolves over time
activity_config = BioConfig(fuzzy_lr_distribution=Distribution.ACTIVITY, fuzzy_lr_dynamic=True)  # Based on activation
```

You can update the configuration of a BioConverter after it has been created:
```python
converter = BioConverter()
converter.update_config(
    dampening_factor=0.7,
    crystal_thresh=5e-05
)

# Or create a converter from a dictionary
config_dict = {
    'fuzzy_learning_rate_factor_nu': 0.2,
    'dampening_factor': 0.7
}
converter = BioConverter.from_dict(config_dict)
```

You can convert existing model instances:

```python
import torchvision

pretrained_model = torchvision.models.resnet18(pretrained=True)
bio_model = converter.convert(pretrained_model)
```

Or use the converter as a decorator:
```python
@converter
class BioResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet18(pretrained=True)
        # Additional layers...
```

## Extending Functionality

You can extend Bio Transformations with your own bio-inspired methods:
- Add the function to `BioModule`:

```python
# In bio_transformations/bio_module.py
class BioModule(nn.Module):
    # Add your function to the exposed_functions tuple
    exposed_functions = (
        "rejuvenate_weights",
        "crystallize",
        "fuzzy_learning_rates",
        "volume_dependent_lr",
        "my_new_function",  # <-- Add your function name here
        # ... other existing functions
    )

    # Add your function implementation
    def my_new_function(self) -> None:
        """
        Your new bio-inspired function.

        This function implements a new bio-inspired mechanism for neural networks.
        """
        with torch.no_grad():
            # Example: add random noise to the weights
            noise = torch.randn_like(self.get_parent().weight.data) * 0.01
            self.get_parent().weight.data += noise
```

- Add parameters to `BioConfig` if needed:
```python
# In bio_transformations/bio_config.py
class BioConfig(NamedTuple):
    # Existing parameters...
    my_new_parameter: float = 0.5  # Default value for your new parameter
```

- Create a test case in test_biomodule.py:
```python
# In test_biomodule.py
def test_my_new_function():
    """Test the my_new_function method of BioModule."""
    linear_layer = nn.Linear(10, 10)
    bio_mod = BioModule(lambda: linear_layer)

    # Save initial weights for comparison
    initial_weights = linear_layer.weight.data.clone()

    # Call your new function
    bio_mod.my_new_function()

    # Verify the function had the expected effect
    assert not torch.allclose(linear_layer.weight.data, initial_weights), \
        "Weights should change after calling my_new_function"
```

- Update the documentation in the appropriate RST files:
```rst
.. method:: my_new_function()

   Your new bio-inspired function.

   This function implements a new bio-inspired mechanism for neural networks.
   It uses the ``my_new_parameter`` from the configuration to control its behavior.
```

You can also create your own custom BioModule class with specialized functionality:
```python
import torch.nn as nn
from bio_transformations import BioConverter
from bio_transformations.bio_module import BioModule

class CustomBioModule(BioModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Additional initialization

    def custom_bio_method(self):
        # Your custom bio-inspired logic here
        pass

# Update exposed_functions to include your new method
CustomBioModule.exposed_functions = BioModule.exposed_functions + ("custom_bio_method",)

# Use CustomBioModule in your BioConverter
class CustomBioConverter(BioConverter):
    def _bio_modulize(self, module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.add_module('bio_mod', CustomBioModule(lambda: module, **self.bio_module_params))

# Use your custom converter
custom_converter = CustomBioConverter()
bio_model = custom_converter(model)
```

## Contributing Guidelines

We welcome contributions to Bio Transformations! Please follow these steps:
- Fork the repository and create your branch from `main`.
- Make your changes and ensure all tests pass.
- Add tests for new functionality.
- Update the documentation to reflect your changes.
- Submit a pull request with a clear description of your changes.
Please adhere to the existing code style and include appropriate comments.
## License Information

This project is licensed under the MIT License. See the LICENSE file for details.
## Publication

For more detailed information about the project and its underlying research, please refer to our paper: [DOI]