Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions acc/components/optics/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from math import ceil, sqrt, tanh
from numba import cuda, float32, void
from numba import cuda, float32, int64, void
from time import time

from mcni.AbstractComponent import AbstractComponent
Expand Down Expand Up @@ -131,20 +131,26 @@ def propagate(

@cuda.jit(void(float32, float32, float32, float32, float32,
               float32, float32, float32, float32, float32,
               float32[:, :], int64))
def process_kernel(
        ww, hh, hw1, hh1, l,
        R0, Qc, alpha, m, W,
        neutrons, batch_size
):
    """Apply ``propagate`` to one contiguous batch of neutrons per thread.

    The thread with global 1-D index ``i`` (from ``cuda.grid(1)``) handles
    rows ``i * batch_size`` up to, but not including,
    ``min(len(neutrons), (i + 1) * batch_size)``.

    Parameters
    ----------
    ww, hh, hw1, hh1, l, R0, Qc, alpha, m, W : float32
        Forwarded unchanged to ``propagate`` for every neutron.
    neutrons : float32[:, :]
        2-D array; each row ``neutrons[k]`` is handed to ``propagate``.
    batch_size : int64
        Number of consecutive rows assigned to each thread.
    """
    neutron_count = len(neutrons)
    work_unit = cuda.grid(1)
    neutron_start = work_unit * batch_size
    # Clamp the end of the batch so the last active thread stops at the end
    # of the array; threads whose start lies past the end get an empty range
    # and do nothing.
    neutron_end = min(neutron_count, neutron_start + batch_size)
    for neutron_index in range(neutron_start, neutron_end):
        propagate(
            ww, hh, hw1, hh1, l,
            R0, Qc, alpha, m, W,
            neutrons[neutron_index]
        )


# Cap on the thread count computed in call_process: it takes the smaller of
# this value and the neutron count, and gives each thread several neutrons
# when there are more neutrons than threads.
thread_count_target = 1e5
# Threads per block used in the process_kernel launch configuration.
threads_per_block = 512


def call_process(
Expand All @@ -153,13 +159,15 @@ def call_process(
in_neutrons
):
neutron_count = len(in_neutrons)
threads_per_block = 512
number_of_blocks = ceil(neutron_count / threads_per_block)
print("{} blocks, {} threads".format(number_of_blocks, threads_per_block))
process_kernel[number_of_blocks, threads_per_block](
thread_count = min(thread_count_target, neutron_count)
block_count = ceil(thread_count / threads_per_block)
neutrons_per_thread = ceil(neutron_count / thread_count)
print("{} blocks, {} threads per block, {} neutrons per thread".format(
block_count, threads_per_block, neutrons_per_thread))
process_kernel[block_count, threads_per_block](
ww, hh, hw1, hh1, l,
R0, Qc, alpha, m, W,
in_neutrons
in_neutrons, neutrons_per_thread
)
cuda.synchronize()

Expand Down