Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions brainevent/_fcn/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,15 @@ def _binary_fcnmv_cuda_kernel(
if transpose:
# Scatter mode: if is_active(spikes[i]) → output[indices[i,k]] += weights[i,k]
# Always TPR (thread-per-row) — atomicAdd contention is the bottleneck at all n_conn.
kernel_name = f'fcn_binary_mv.binary_fcnmv_scatter{mode_sfx}{spike_sfx}{sfx}'
#kernel_name = f'fcn_binary_mv.binary_fcnmv_scatter{mode_sfx}{spike_sfx}{sfx}'
kernel_name = f'fcn_binary_mv.binary_fcnmv_scatter{mode_sfx}_bool{sfx}'
else:
# Gather mode: y[i] = sum_k weights[i,k] * is_active(spikes[indices[i,k]])
# Auto-dispatch inside CUDA: TPR for n_conn<=512, MR for n_conn>512.
kernel_name = f'fcn_binary_mv.binary_fcnmv_gather{mode_sfx}{spike_sfx}{sfx}'
kernel_name = f'fcn_binary_mv.binary_fcnmv_gather{mode_sfx}_bool{sfx}'

def kernel(weights, indices, spikes):
#spikes = u.math.asarray(spikes, dtype=bool)
spikes = u.math.asarray(spikes, dtype=bool)
return jax.ffi.ffi_call(kernel_name, out_info)(weights, indices, spikes)

return kernel
Comment on lines 275 to 284
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): spk_f construction is inconsistent and appears unused/incorrectly overwritten.

In the non-transpose branch, spk_f is first built as a float mask and then immediately replaced by u.math.asarray(spikes, dtype=bool), making the earlier construction dead code. Also, in the shown snippet spk_f is never used, so the conversion appears to have no effect.

Please either:

  • remove the initial spk_f construction and keep only the representation actually needed, or
  • ensure spk_f is the value converted and passed into the kernel if that’s what the kernel expects.

Aligning this with the kernel’s expected type will avoid subtle type/semantics issues and confusion from unused variables.

Expand All @@ -300,6 +301,7 @@ def kernel(weights, indices, spikes):
else:
spk_f = (spikes > 0).astype(weights.dtype)

spk_f = u.math.asarray(spikes, dtype=bool)

if transpose:
# Scatter: y[indices[i,k]] += weights[i,k] * spk_f[i]
Expand Down
182 changes: 136 additions & 46 deletions brainevent/_fcn/binary_fcnmv.cu

Large diffs are not rendered by default.

133 changes: 127 additions & 6 deletions dev/fcn/CsvOutput.py → dev/fcn/BenchmarkTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(
suffix: str = '',
output_dir: str | None = None,
append: bool = False,

) -> None:
self.name = CSV_name
self.suffix = suffix
Expand All @@ -64,6 +63,7 @@ def __init__(
self.output_dir = Path(output_dir) if output_dir else base / 'results'
self.output_dir.mkdir(parents=True, exist_ok=True)
self.append = append
self._tags: dict = {}

self.width = 70

Expand All @@ -73,24 +73,54 @@ def _write_csv(self, file_name: str, rows: list[dict], fieldnames: list[str], mo

file_path = Path(self.output_dir) / f'{file_name}.csv'
write_header = True
effective_fieldnames = list(fieldnames)

if mode == 'a' and file_path.exists():
# if appending and file exists, do not write header again
write_header = False
# Read existing fieldnames to maintain schema consistency
with file_path.open('r', newline='', encoding='utf-8') as f:
reader = csv.reader(f)
existing_fields = next(reader, [])
if existing_fields:
# Union: existing fields first, then any new fields not yet present
merged = list(existing_fields)
for fn in fieldnames:
if fn not in merged:
merged.append(fn)
effective_fieldnames = merged

with file_path.open(mode, newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames, restval=None)
writer = csv.DictWriter(f, fieldnames=effective_fieldnames, restval='default')
if write_header:
writer.writeheader()
writer.writerows(rows)

print(f"result has been saved: {file_path}")

def add_tag(self, tag_name: str, tag_value) -> None:
    """Register a persistent tag merged into every subsequently recorded row.

    Invoke before ``single_COBA_data_add`` (or ``add_row``) so the extra
    labeled field is attached to each row recorded afterward.  The tag's
    column is appended to ``fieldnames`` right away so it appears in the CSV.
    """
    self._tags[tag_name] = tag_value
    known_columns = self.fieldnames
    if tag_name in known_columns:
        return
    known_columns.append(tag_name)

def add_row(self, row: dict) -> None:
"""Add a generic row (dict). New keys will be added to fieldnames."""
for key in row.keys():
"""Add a generic row (dict). New keys will be added to fieldnames.

Active tags (set via ``add_tag``) are merged into the row automatically;
explicit values in *row* take precedence over tag values.
"""
# Merge active tags first, then let row values override
merged_row = dict(self._tags)
merged_row.update(row)
for key in merged_row.keys():
if key not in self.fieldnames:
self.fieldnames.append(key)
self.rows.append(row)
self.rows.append(merged_row)

def single_COBA_data_add(self, operator: str,
data_type: str,
Expand Down Expand Up @@ -233,3 +263,94 @@ def dump_jax_ir(func, args=(), kwargs=None, prefix="dump"):
print(f"[*] HLO saved to: {hlo_path}")

return jaxpr_path, hlo_path

def generate_params(
    dis_type: str = 'log',
    _N: int = 4000,
    limit_gb: float = 24,
    target_samples: int = 500,
    data_size: int = 4,
    scale_max: int = 2000,
    conn_max: int = 4000,
) -> list:
    """
    Generates a list of valid (scale, conn_num) parameter states within VRAM limits.

    Boundary conditions
    -------------------
    - matrix_memory_bytes = conn_num * scale * _N * data_size * 2 <= limit_bytes
    - conn_num <= _N * scale

    Parameters
    ----------
    dis_type : str
        Sampling strategy: 'uniform' (linear grid), 'log' (geometric grid),
        or 'monte_carlo' (random sampling).
    _N : int
        Number of neurons per scale unit.
    limit_gb : float
        VRAM limit in gigabytes.
    target_samples : int
        Approximate number of valid states to generate (actual count ≈ ±50).
        For 'monte_carlo' this is an upper target: fewer states are returned
        when the valid region is smaller than the target.
    data_size : int
        Bytes per element (4 for float32/int32, 1 for bool/int8).
    scale_max : int
        Upper bound of scale search range.
    conn_max : int
        Upper bound of conn_num search range.

    Returns
    -------
    list of (scale, conn_num) tuples, sorted by memory footprint.

    Raises
    ------
    ValueError
        If *dis_type* is not one of 'uniform', 'log', 'monte_carlo'.
    """
    import numpy as np

    limit_bytes = limit_gb * (1024 ** 3)

    def is_valid(s, c):
        # Factor 2: two matrices (e.g. weights + indices) of shape (scale*_N, conn_num).
        matrix_memory_bytes = c * s * _N * data_size * 2
        return matrix_memory_bytes <= limit_bytes and c <= _N * s

    if dis_type == 'monte_carlo':
        valid_states = set()
        # Cap attempts so an over-constrained search space (fewer valid states
        # than target_samples) cannot spin forever.
        max_attempts = max(target_samples, 1) * 1000
        attempts = 0
        while len(valid_states) < target_samples and attempts < max_attempts:
            attempts += 1
            # randint samples integers directly (inclusive low, exclusive high).
            s = int(np.random.randint(1, scale_max + 1))
            c = int(np.random.randint(1, conn_max + 1))
            if is_valid(s, c):
                valid_states.add((s, c))
        sorted_states = sorted(valid_states, key=lambda state: state[0] * state[1])
        print(f"Generated {len(sorted_states)} valid parameter states under {limit_gb}GB boundary.")
        return sorted_states

    # For grid-based methods: the valid region is a curved hyperbolic area;
    # ~3x oversampling of grid points relative to target gives ~±50 accuracy after filtering.
    grid_res = int(np.sqrt(target_samples * 3))

    if dis_type == 'uniform':
        scales_raw = np.unique(np.linspace(1, scale_max, num=grid_res, dtype=int))
        conn_nums_raw = np.unique(np.linspace(1, conn_max, num=grid_res, dtype=int))
    elif dis_type == 'log':
        scales_raw = np.unique(np.geomspace(1, scale_max, num=grid_res, dtype=int))
        conn_nums_raw = np.unique(np.geomspace(1, conn_max, num=grid_res, dtype=int))
    else:
        raise ValueError(f"Unknown dis_type: '{dis_type}'. Choose from 'uniform', 'log', 'monte_carlo'.")

    valid_states = [
        (int(s), int(c))
        for s in scales_raw
        for c in conn_nums_raw
        if is_valid(s, c)
    ]
    valid_states.sort(key=lambda state: state[0] * state[1])
    print(f"Generated {len(valid_states)} valid parameter states under {limit_gb}GB boundary.")
    return valid_states

def memory_limit( conn_nums, scale:int ,_N:int = 4000 , limit: int = 16, data_type: str = 'float'):
if data_type == 'float': data_size = 4
if data_type == 'int': data_size = 4
if data_type == 'bool': data_size = 1
if data_type == 'char': data_size = 1
if conn_nums * data_size * scale * _N * 2 > limit * (1024 ** 3):
return True
else:
return False
12 changes: 6 additions & 6 deletions dev/fcn/COBA_2005_binary_fcnmm_CsvOuput.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def benchmark_post_conn(batch_size=16, conn_num=None, conn_prob=None, data_type=
# scale=100, size=400000, time = 20.511342525482178 s, firing rate = 59.44398880004883 Hz
print('Benchmarking post-synaptic connection updates...')

import CsvOutput as rp
import dev.fcn.BenchmarkTools as BT

backends_to_use = [backend] if backend is not None else backends

Expand All @@ -90,7 +90,7 @@ def benchmark_post_conn(batch_size=16, conn_num=None, conn_prob=None, data_type=
use_conn_nums = False
probs_to_use = [conn_prob] if conn_prob is not None else probs

csv_recorder = rp.CSV_record('binary_post', 'fcnmm', 'coba', duration=duration, conn=conn_num)
csv_recorder = BT.CSV_record('binary_post', 'fcnmm', 'coba', duration=duration, conn=conn_num)

for back in backends_to_use:
brainevent.config.set_backend('gpu', back)
Expand Down Expand Up @@ -160,7 +160,7 @@ def benchmark_post_conn(batch_size=16, conn_num=None, conn_prob=None, data_type=

def benchmark_pre_conn(batch_size=16, conn_num=None, conn_prob=None, data_type='binary', duration=1e3 * u.ms, homo: bool = True, backend: str | None = None, probs_or_conn='conn'):
print('Benchmarking pre-synaptic connection updates...')
import CsvOutput as rp
import dev.fcn.BenchmarkTools as BT

backends_to_use = [backend] if backend is not None else backends

Expand All @@ -171,7 +171,7 @@ def benchmark_pre_conn(batch_size=16, conn_num=None, conn_prob=None, data_type='
use_conn_nums = False
probs_to_use = [conn_prob] if conn_prob is not None else probs

csv_recorder = rp.CSV_record('binary_pre', 'fcnmm', 'coba', duration=duration, conn=conn_num)
csv_recorder = BT.CSV_record('binary_pre', 'fcnmm', 'coba', duration=duration, conn=conn_num)

for back in backends_to_use:
brainevent.config.set_backend('gpu', back)
Expand Down Expand Up @@ -241,7 +241,7 @@ def benchmark_pre_conn(batch_size=16, conn_num=None, conn_prob=None, data_type='

def run_benchmark(batch_size, conn_num=None, mode='post', homo: bool = True, backend: str | None = None):

import CsvOutput as rp
import dev.fcn.BenchmarkTools as BT

conn_nums_to_use = [conn_num] if conn_num is not None else conn_nums
dur = 1e3 * u.ms if mode == 'post' else 1e2 * u.ms
Expand All @@ -256,7 +256,7 @@ def run_benchmark(batch_size, conn_num=None, mode='post', homo: bool = True, bac
kernel = 'warp' if cn <= 32 else 'basic'

# CSV recorder for this configuration
csv_recorder = rp.CSV_record(f'binary_bs{batch_size}_conn{cn}', 'fcnmm', 'benchmark', duration=dur, conn=cn)
csv_recorder = BT.CSV_record(f'binary_bs{batch_size}_conn{cn}', 'fcnmm', 'benchmark', duration=dur, conn=cn)

for back in backends_to_use:
brainevent.config.set_backend('gpu', back)
Expand Down
36 changes: 15 additions & 21 deletions dev/fcn/COBA_2005_binary_fcnmv_CsvOuput.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,12 @@
from COBA_2005_benchmark import make_simulation_run


scales = [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]
backends = ['cuda_raw','jax_raw']
scales = [1140]
backends = ['cuda_raw']

conn_nums = [20, 40, 80, 160, 320, 640, ]
conn_nums = [686]

probs = [0.001, 0.004, 0.016 ,0.064, 0.128, 0.256]


def memory_limit( conn_nums, scale:int ,_N:int = 4000 , limit: int = 16):
if conn_nums * 4 * scale * _N > limit * (1024 ** 3):
return True
else:
return False
probs = [0.001, 0.004, 0.016 ,0.064, 0.128, 0.256, 0.512]


def benchmark_post_conn(
Expand All @@ -65,7 +58,7 @@ def benchmark_post_conn(
probs_or_conn='conn',
_N : int = 4000
):
import CsvOutput as RP
import dev.fcn.BenchmarkTools as BT

print('Benchmarking post-synaptic connection updates...')

Expand All @@ -78,7 +71,7 @@ def benchmark_post_conn(
use_conn_nums = False
probs_to_use = [conn_prob] if conn_prob is not None else probs

csv_recorder = RP.CSV_record('binary_post', 'fcnmv', 'coba', duration=duration, conn=conn_num)
csv_recorder = BT.CSV_record('binary_post', 'fcnmv', 'coba', duration=duration, conn=conn_num)

for back in backends_to_use:
brainevent.config.set_backend('gpu', back)
Expand Down Expand Up @@ -122,7 +115,7 @@ def benchmark_post_conn(
for s in scales:
actual_conn_num = int(s * prob * _N)
if actual_conn_num < 1 : actual_conn_num = 1
if memory_limit(actual_conn_num, scale=s): continue
if BT.memory_limit(actual_conn_num, scale=s): continue
try:
run = make_simulation_run(
scale=s,
Expand All @@ -145,7 +138,7 @@ def benchmark_post_conn(
print(f' [Error] scale={s}, conn_num={actual_conn_num}: {e}')
continue

csv_recorder.record_finish('post_cuda-jax_great_scale')
csv_recorder.record_finish('float_mode_single_point_with_spconn-scale')

def benchmark_pre_conn(
conn_num=None,
Expand All @@ -158,7 +151,7 @@ def benchmark_pre_conn(
_N : int = 4000
):
print('Benchmarking pre-synaptic connection updates...')
import CsvOutput as RP
import dev.fcn.BenchmarkTools as BT

backends_to_use = [backend] if backend is not None else backends

Expand All @@ -169,7 +162,7 @@ def benchmark_pre_conn(
use_conn_nums = False
probs_to_use = [conn_prob] if conn_prob is not None else probs

csv_recorder = RP.CSV_record('binary_pre', 'fcnmv', 'coba', duration=duration, conn=conn_num)
csv_recorder = BT.CSV_record('binary_pre', 'fcnmv', 'coba', duration=duration, conn=conn_num)

for back in backends_to_use:
brainevent.config.set_backend('gpu', back)
Expand Down Expand Up @@ -213,7 +206,7 @@ def benchmark_pre_conn(
for s in scales:
actual_conn_num = int(s * prob * _N) # non-linear: conn_num scales with network size
if actual_conn_num < 1 : actual_conn_num = 1
if memory_limit(actual_conn_num, scale=s): continue
if BT.memory_limit(actual_conn_num, scale=s): continue
try:
run = make_simulation_run(
scale=s,
Expand All @@ -240,7 +233,8 @@ def benchmark_pre_conn(


if __name__ == '__main__':
#benchmark_post_conn(conn_num=80, data_type='binary', duration=1e4 * u.ms, backend='jax_raw')
benchmark_post_conn(data_type='binary', duration=1e2 * u.ms, probs_or_conn='prob')
benchmark_post_conn(data_type='compact', duration=1e2 * u.ms, probs_or_conn='conn')
benchmark_post_conn(data_type='binary', duration=1e2 * u.ms, probs_or_conn='conn')
#benchmark_pre_conn(conn_num=80, data_type='bitpack', duration=1e3 * u.ms)
#benchmark_pre_conn(data_type='binary',duration=1e3 * u.ms,)
#benchmark_pre_conn(data_type='binary',duration=1e3 * u.ms, probs_or_conn='conn')
#benchmark_pre_conn(data_type='bitpack',duration=1e3 * u.ms, probs_or_conn='conn')
Loading
Loading