diff --git a/acc/components/optics/guide.py b/acc/components/optics/guide.py index b9c0ea63..a2969f97 100644 --- a/acc/components/optics/guide.py +++ b/acc/components/optics/guide.py @@ -5,7 +5,7 @@ import numpy as np from math import ceil, sqrt, tanh -from numba import cuda, float32, void +from numba import cuda, float32, int64, void from time import time from mcni.AbstractComponent import AbstractComponent @@ -131,20 +131,26 @@ def propagate( @cuda.jit(void(float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, - float32[:, :])) + float32[:, :], int64)) def process_kernel( ww, hh, hw1, hh1, l, R0, Qc, alpha, m, W, - neutrons + neutrons, batch_size ): - x = cuda.grid(1) - if x < len(neutrons): + neutron_count = len(neutrons) + work_unit = cuda.grid(1) + neutron_start = work_unit * batch_size + neutron_end = min(neutron_count, neutron_start + batch_size) + for neutron_index in range(neutron_start, neutron_end): propagate( ww, hh, hw1, hh1, l, R0, Qc, alpha, m, W, - neutrons[x] + neutrons[neutron_index] ) - return + + +thread_count_target = 1e5 +threads_per_block = 512 def call_process( @@ -153,13 +159,15 @@ def call_process( in_neutrons ): neutron_count = len(in_neutrons) - threads_per_block = 512 - number_of_blocks = ceil(neutron_count / threads_per_block) - print("{} blocks, {} threads".format(number_of_blocks, threads_per_block)) - process_kernel[number_of_blocks, threads_per_block]( + thread_count = min(thread_count_target, neutron_count) + block_count = ceil(thread_count / threads_per_block) + neutrons_per_thread = ceil(neutron_count / thread_count) + print("{} blocks, {} threads per block, {} neutrons per thread".format( + block_count, threads_per_block, neutrons_per_thread)) + process_kernel[block_count, threads_per_block]( ww, hh, hw1, hh1, l, R0, Qc, alpha, m, W, - in_neutrons + in_neutrons, neutrons_per_thread ) cuda.synchronize()