policy = Pi0Inference(checkpoint=weights, num_views=4, chunk_size=50)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "pi0_infer.py", line 1341, in __init__
self.record_infer_graph()
File "pi0_infer.py", line 1349, in record_infer_graph
self.record_run()
File "pi0_infer.py", line 1345, in record_run
pi0_model(self.weights, self.buffers, self.num_views)
File "pi0_infer.py", line 1236, in pi0_model
transformer_decoder(weights, buffers, encoder_seq_len)
File "pi0_infer.py", line 1195, in transformer_decoder
matmul_k8_256_n_softmax_mask0(
File "pi0_infer.py", line 1104, in matmul_k8_256_n_softmax_mask0
softmax_kernel_mask0[((total_queries + 3) // 4,)](out,
File ".venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 347, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 591, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
^^^^^^^
File ".venv/lib/python3.12/site-packages/triton/compiler/compiler.py", line 408, in _init_handles
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
^^^^^^^^^^^^^^^^^^^
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
It looks like the current implementation supports num_views <= 3. When we set
num_views = 4 in Pi0Inference(), an illegal memory access error occurs.
Is this restriction (num_views <= 3) intended? Would it be possible to support
cases where num_views > 3?