68 changes: 37 additions & 31 deletions README.md
@@ -18,8 +18,8 @@ The SpikeRNN framework consists of two complementary packages with a modern **ta
### Spiking RNN Package
- Rate-to-spike conversion maintaining task performance
- Biologically realistic leaky integrate-and-fire (LIF) neurons
- Spiking task evaluation classes (GoNogoSpikingTask, XORSpikingTask, ManteSpikingTask)
- SpikingTaskFactory for task-based evaluation
- Spiking task evaluation classes (GoNogoSpikingEvaluator, XORSpikingEvaluator, ManteSpikingEvaluator)
- SpikingEvaluatorFactory for task-based evaluation
- Scaling factor optimization for optimal conversion

## Task-Based Architecture
@@ -44,21 +44,20 @@ After installation, you can import both packages:

```python
from rate import FR_RNN_dale, set_gpu, create_default_config
from spiking import LIF_network_fnc, lambda_grid_search, evaluate_task
from spiking import LIF_network_fnc, lambda_grid_search

# Task-based architecture
from rate import TaskFactory
from spiking import SpikingTaskFactory
from spiking.eval_tasks import SpikingEvaluatorFactory
from rate.tasks import GoNogoTask, XORTask, ManteTask
from spiking.tasks import GoNogoSpikingTask, XORSpikingTask, ManteSpikingTask
from spiking.eval_tasks import GoNogoSpikingEvaluator, XORSpikingEvaluator, ManteSpikingEvaluator
```

## Quick Start: Task-Based Architecture

### Creating and Using Tasks

```python
from spikeRNN import TaskFactory

# Create task settings
settings = {
@@ -81,23 +80,22 @@ print(f"Generated {task.__class__.__name__} trial with label: {label}")
The framework provides two levels of evaluation:

```python
from spiking import SpikingTaskFactory, evaluate_task
from spiking.eval_tasks import SpikingEvaluatorFactory, evaluate_task

# Direct evaluation (single trial, given a model file and scaling factor)
spiking_task = SpikingTaskFactory.create_task('go_nogo')
performance = spiking_task.evaluate_performance(spiking_rnn, n_trials=100)
print(f"Accuracy: {performance['overall_accuracy']:.2f}")
spiking_evaluator = SpikingEvaluatorFactory.create_evaluator('go_nogo', settings)
performance = spiking_evaluator.evaluate_single_trial(model_path, scaling_factor)
print(f"Trial result: {performance}")

# High-level interface (when you have model files with trained weights)
performance = evaluate_task(
    task_name='go_nogo',
    model_dir='models/go-nogo',
    n_trials=100,
    model_path='models/go-nogo/model.mat',
    save_plots=True
)

# Command line interface (for scripts and automation)
# python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/
# python -m spiking.eval_tasks --task go_nogo --model_path models/go-nogo/model.mat
```
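Since `evaluate_single_trial` returns 1 for a correct trial and 0 otherwise, a multi-trial accuracy is just the mean of the per-trial results. A stand-alone sketch of that aggregation — `run_trial` here is a stand-in for the `spiking_evaluator.evaluate_single_trial(...)` call above, not part of the spikeRNN API:

```python
import random

random.seed(0)

# Stand-in for spiking_evaluator.evaluate_single_trial(model_path, scaling_factor):
# returns 1 if the trial was answered correctly, 0 otherwise.
def run_trial():
    return 1 if random.random() < 0.9 else 0

n_trials = 100
results = [run_trial() for _ in range(n_trials)]
accuracy = sum(results) / n_trials
print(f"Accuracy: {accuracy:.2f}")
```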

### Extending with Custom Tasks
@@ -128,28 +126,32 @@ stimulus, target, label = custom_task.simulate_trial()
Create custom spiking evaluation tasks:

```python
from spiking.tasks import AbstractSpikingTask, SpikingTaskFactory
from spiking.eval_tasks import SpikingEvaluatorFactory
from rate.tasks import AbstractTask

class MyCustomSpikingTask(AbstractSpikingTask):
    def get_default_settings(self):
        return {'T': 200, 'custom_param': 1.0}

    def get_sample_trial_types(self):
        return ['type_a', 'type_b']  # For visualization
class MyCustomSpikingEvaluator(AbstractTask):
    def __init__(self, settings):
        super().__init__(settings)
        self.eval_amp_thresh = settings.get('eval_amp_thresh', 0.7)  # custom value

    def generate_stimulus(self, trial_type=None):
        # Generate stimulus logic
        return stimulus, label
    def validate_settings(self):
        # Validation logic for custom task
        required_keys = ['T', 'custom_param']
        for key in required_keys:
            if key not in self.settings:
                raise ValueError(f"Missing required setting: {key}")

    def evaluate_performance(self, spiking_rnn, n_trials=100):
        # Multi-trial performance evaluation
        return {'accuracy': 0.85, 'n_trials': n_trials}
    def evaluate_single_trial(self, model_path: str, scaling_factor: float) -> int:
        """Evaluate a single trial for the custom task."""
        # Custom evaluation logic here
        # Return 1 if correct, 0 if incorrect
        raise NotImplementedError

# Register and use with evaluation system
SpikingTaskFactory.register_task('my_custom', MyCustomSpikingTask)
SpikingEvaluatorFactory._registry['my_custom'] = MyCustomSpikingEvaluator

# Now works with eval_tasks.py
# python -m spiking.eval_tasks --task my_custom --model_dir models/custom/
# python -m spiking.eval_tasks --task my_custom --model_path models/custom/model.mat
```
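The `_registry` assignment above works because the factory keeps a plain dict from task name to evaluator class. A self-contained sketch of that pattern — all names here (`EvaluatorFactory`, `DemoEvaluator`, `'demo'`) are illustrative, not part of the spikeRNN API:

```python
# Minimal sketch of the name -> evaluator-class registry used by the factory.
class EvaluatorFactory:
    _registry = {}

    @classmethod
    def register(cls, name, evaluator_cls):
        cls._registry[name] = evaluator_cls

    @classmethod
    def create_evaluator(cls, name, settings):
        if name not in cls._registry:
            raise ValueError(f"Unknown task: {name!r}")
        return cls._registry[name](settings)

class DemoEvaluator:
    def __init__(self, settings):
        self.settings = settings

    def evaluate_single_trial(self, model_path, scaling_factor):
        # Placeholder: a real evaluator would run the LIF network here.
        return 1

EvaluatorFactory.register('demo', DemoEvaluator)
evaluator = EvaluatorFactory.create_evaluator('demo', {'T': 200})
print(evaluator.evaluate_single_trial('model.mat', 50.0))  # 1
```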

## Requirements
@@ -201,16 +203,20 @@ import numpy as np

# Optimize scaling factor
lambda_grid_search(
    model_dir='models/go-nogo',
    model_path='models/go-nogo/model.mat',
    task_name='go-nogo',
    n_trials=100,
    scaling_factors=list(np.arange(25, 76, 5))
    scaling_factors=list(np.arange(25, 76, 5)),
    task_settings=settings
)

# Evaluate performance
performance = evaluate_task(
    task_name='go_nogo',
    model_dir='models/go-nogo/'
    model_path='models/go-nogo/model.mat',
    n_trials=50,
    task_settings=settings,
    all_trial_tasks=True
)
```
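Conceptually, `lambda_grid_search` just evaluates each candidate scaling factor and keeps the best performer. A stand-alone sketch of that loop — `evaluate` is a hypothetical stand-in for running `n_trials` spiking simulations at one scaling factor:

```python
import numpy as np

# Stand-in for evaluating the spiking network at one scaling factor;
# the real search would run n_trials LIF simulations and return accuracy.
def evaluate(scaling_factor):
    return 1.0 - abs(scaling_factor - 50) / 100.0  # hypothetical curve peaking at 50

scaling_factors = list(np.arange(25, 76, 5))
perfs = [evaluate(s) for s in scaling_factors]
opt_scaling_factor = scaling_factors[int(np.argmax(perfs))]
print(opt_scaling_factor)  # 50
```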

60 changes: 22 additions & 38 deletions docs/api/spiking/eval_tasks.rst
@@ -35,13 +35,10 @@ The eval_tasks module provides a high-level evaluation interface that standardiz
* **Robust Error Handling**: Graceful handling of evaluation failures
* **Flexible Visualization**: Generic visualization system for any task type

**Evaluation Layers:**
**Evaluation:**

The framework provides three levels of evaluation:

1. **Core Task Methods**: Direct task evaluation (``task.evaluate_performance()``)
2. **High-Level Interface**: Complete workflow (``evaluate_task()``)
3. **Command-Line Interface**: Batch processing (``python -m spiking.eval_tasks``)
1. **High-Level Interface**: Complete workflow (``evaluate_task()``)
2. **Command-Line Interface**: Batch processing (``python -m spiking.eval_tasks``)

Usage Examples
----------------------------------------------------------------------------------
@@ -54,29 +51,29 @@ Usage Examples

# Evaluate any registered task
performance = evaluate_task(
    task_name='go_nogo', # or 'xor', 'mante', custom tasks
    model_dir='models/go-nogo/',
    save_plots=True
    task_name='go_nogo',
    model_path='models/go-nogo/model.mat',
    n_trials=50
)

print(f"Accuracy: {performance['overall_accuracy']:.3f}")
print(f"Performance: {performance}")

**Command-Line Interface:**

.. code-block:: bash

# Basic evaluation
python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/
python -m spiking.eval_tasks --task go_nogo --model_path models/go-nogo/model.mat

# With custom parameters
python -m spiking.eval_tasks \
    --task xor \
    --model_dir models/xor/ \
    --model_path models/xor/model.mat \
    --scaling_factor 45.0 \
    --no_plots
    --n_trials 50

# Custom task (after registration)
python -m spiking.eval_tasks --task my_custom --model_dir models/custom/
python -m spiking.eval_tasks --task my_custom --model_path models/custom/model.mat

**Custom Task Integration:**

@@ -95,8 +92,8 @@ Usage Examples

# 3. Evaluate using unified interface
performance = evaluate_task(
    task_name='working_memory', # Now supported automatically
    model_dir='models/working_memory/',
    task_name='working_memory',
    model_path='models/working_memory/model.mat',
)

Command-Line Arguments
@@ -108,22 +105,26 @@ Command-Line Arguments

Task to evaluate. Available tasks are dynamically determined from the factory registry.

.. option:: --model_dir MODEL_DIR
.. option:: --model_path MODEL_PATH

Directory containing the trained model .mat file.
Path to the trained model .mat file.

.. option:: --scaling_factor SCALING_FACTOR

Override scaling factor (uses value from .mat file if not provided).

.. option:: --no_plots
.. option:: --n_trials N_TRIALS

Skip generating visualization plots.
Number of trials to evaluate.

.. option:: --T T

Trial duration (timesteps) - overrides task default.

.. option:: --delay DELAY

Delay time (timesteps) - overrides task default.

.. option:: --stim_on STIM_ON

Stimulus onset time - overrides task default.
@@ -141,21 +142,4 @@ The system automatically loads trained rate RNN models from `.mat` files and ext

* Network weights and connectivity matrices
* Optimal scaling factors for rate-to-spike conversion
* Task-specific parameters and configurations

**Generic Visualization:**

The visualization system uses each task's ``get_sample_trial_types()`` method to determine what trial types to generate for plotting. This allows custom tasks to specify their own visualization patterns without modifying the evaluation code.

**Error Handling:**

The evaluation system includes comprehensive error handling:

* Graceful handling of missing model files
* Validation of task names against factory registry
* Recovery from trial generation failures
* Informative error messages for debugging

**Extensibility:**

The system is designed to be fully extensible. Any task that inherits from ``AbstractSpikingTask`` and is registered with ``SpikingTaskFactory`` can be evaluated using this unified interface.
* Task-specific parameters and configurations
13 changes: 0 additions & 13 deletions docs/api/spiking/index.rst
@@ -11,7 +11,6 @@ Core Modules
:maxdepth: 1

lif_network
tasks
eval_tasks
lambda_grid_search
utils
@@ -22,9 +21,6 @@ Module Overview
**LIF_network_fnc.py**
Core function for converting rate RNNs to spiking networks and running LIF simulations.

**tasks.py**
Task-based architecture for spiking neural network evaluation with abstract base classes and concrete task implementations.

**eval_tasks.py**
Unified, extensible evaluation interface for spiking neural networks on cognitive tasks.

@@ -43,14 +39,6 @@ Quick Reference
* ``evaluate_task()``: Unified evaluation interface for all tasks
* ``lambda_grid_search()``: Optimize scaling factors

**Task Classes:**

* ``AbstractSpikingTask``: Base class for spiking task evaluation
* ``GoNogoSpikingTask``: Go-NoGo task for spiking networks
* ``XORSpikingTask``: XOR task for spiking networks
* ``ManteSpikingTask``: Mante task for spiking networks
* ``SpikingTaskFactory``: Factory for creating spiking task instances

**Configuration:**

* ``SpikingConfig``: Configuration dataclass for spiking RNN parameters
@@ -60,4 +48,3 @@

* ``load_rate_model()``: Load rate model from `.mat` file
* ``format_spike_data()``: Format spike data for analysis
* ``SpikingTaskFactory.register_task()``: Register custom tasks
16 changes: 8 additions & 8 deletions docs/api/spiking/lambda_grid_search.rst
@@ -16,10 +16,10 @@ Grid Search Parameters

The main grid search function accepts:

* ``model_dir`` (str): Directory containing trained rate RNN model .mat files
  (default: '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0')
* ``task_name`` (str): Task type ('go-nogo', 'xor', or 'mante')
  (default: 'go-nogo')
* ``model_path`` (str): Path to trained rate RNN model .mat file
  (default: '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0/model.mat')
* ``task_name`` (str): Task type ('go_nogo', 'xor', or 'mante')
  (default: 'go_nogo')
* ``n_trials`` (int): Number of trials to evaluate each scaling factor
  (default: 100)
* ``scaling_factors`` (list): List of scaling factors to test
@@ -50,18 +50,18 @@ Example Usage

# Grid search with custom parameters
lambda_grid_search(
    model_dir='models/go-nogo',
    model_path='models/go-nogo/model.mat',
    n_trials=50,
    scaling_factors=list(range(30, 81, 5)),
    task_name='go-nogo'
    task_name='go_nogo'
)

Optimization Process
----------------------------------------------------

The grid search follows these steps:

1. **Load trained rate models** from the specified directory
1. **Load trained rate model** from the specified .mat file
2. **Generate test stimuli** appropriate for the task type
3. **Iterate through scaling factors** in the specified range
4. **Convert to spiking network** for each scaling factor
@@ -93,7 +93,7 @@ Different metrics are used depending on the task:
Output Files
----------------------------------------------------

The function modifies each input `.mat` file to include:
The function modifies the input `.mat` file to include:

* `opt_scaling_factor`: The optimal scaling factor found
* `all_perfs`: Performance scores for all tested scaling factors
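Reading the stored results back out of the updated `.mat` file takes a couple of lines with ``scipy.io``; a minimal sketch, where the file path and the written values are illustrative:

.. code-block:: python

    import os
    import tempfile

    import numpy as np
    from scipy.io import loadmat, savemat

    # Write a stand-in results file with the two fields described above.
    path = os.path.join(tempfile.gettempdir(), 'demo_model.mat')
    savemat(path, {'opt_scaling_factor': 50.0,
                   'all_perfs': np.array([0.62, 0.91, 0.85])})

    m = loadmat(path)
    print(m['opt_scaling_factor'].item())  # 50.0
    print(m['all_perfs'].ravel())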