diff --git a/src/tilegym/ops/cutile/mhc.py b/src/tilegym/ops/cutile/mhc.py index a5fe2a7..41c89f5 100644 --- a/src/tilegym/ops/cutile/mhc.py +++ b/src/tilegym/ops/cutile/mhc.py @@ -511,13 +511,14 @@ def mhc_apply_residual( @ct.kernel def mhc_sinkhorn_kernel( Y, + m: ct.Constant[int], n: ct.Constant[int], ): """Sinkhorn-Knopp normalization for residual block (in-place on Y).""" row = ct.bid(0) total = n * n - mat = ct.load(Y, index=(row, 0), shape=(1, total)) - mat = ct.reshape(mat, (n, n)) + mat = ct.load(Y, index=(row, 0), shape=(m, total)) + mat = ct.reshape(mat, (m, n, n)) mat = ct.astype(mat, ct.float32) mat = ct.exp2(mat * LOG2E) @@ -527,7 +528,7 @@ def mhc_sinkhorn_kernel( col_sum = ct.sum(mat, axis=0, keepdims=True) mat = ct.truediv(mat, col_sum) - mat = ct.reshape(mat, (1, total)) + mat = ct.reshape(mat, (m, total)) mat = ct.astype(mat, Y.dtype) ct.store(Y, index=(row, 0), tile=mat) @@ -535,13 +536,13 @@ def mhc_sinkhorn_kernel( @register_impl("mhc_sinkhorn", backend="cutile") def mhc_sinkhorn( y: torch.Tensor, - n: int, + n: int, tileM: int = 8, **kwargs, ): y = y.contiguous() M, _ = y.shape y_view = y.narrow(1, 2 * n, n * n) - grid = (M,) + grid = (ct.cdiv(M, tileM),) ct.launch( torch.cuda.current_stream(), grid,