 import torch
-from typing import Optional, Tuple, Dict, List
+from typing import Optional, Tuple, Dict, List, Any
 import logging
 from clt.config import CLTConfig
 from torch.distributed import ProcessGroup
@@ -26,9 +26,10 @@ def _compute_mask(x: torch.Tensor, k_per_token: int, x_for_ranking: Optional[tor

         if k_total_batch > 0:
             _, flat_indices = torch.topk(ranking_flat, k_total_batch, sorted=False)
-            mask_flat = torch.zeros_like(x_flat, dtype=torch.bool)
-            mask_flat[flat_indices] = True
-            mask = mask_flat.view_as(x)
+            # Optimized mask creation: build a fresh flat boolean tensor, then reshape
+            mask = torch.zeros(x_flat.numel(), dtype=torch.bool, device=x.device)
+            mask[flat_indices] = True
+            mask = mask.view_as(x)
         else:
             mask = torch.zeros_like(x, dtype=torch.bool)

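The hunk above implements BatchTopK's global selection: every element of the flattened batch competes for the same k_total_batch slots, and the winners are marked in a boolean mask reshaped back to x's shape. A minimal standalone sketch of that pattern, with illustrative names (batch_topk_mask is not a function in this module):

    import torch

    def batch_topk_mask(x: torch.Tensor, k_total_batch: int) -> torch.Tensor:
        # All elements are ranked jointly across the flattened batch.
        x_flat = x.reshape(-1)
        if k_total_batch <= 0:
            return torch.zeros_like(x, dtype=torch.bool)
        _, flat_indices = torch.topk(x_flat, k_total_batch, sorted=False)
        # Fresh flat boolean tensor, marked via advanced indexing, reshaped to x.
        mask = torch.zeros(x_flat.numel(), dtype=torch.bool, device=x.device)
        mask[flat_indices] = True
        return mask.view_as(x)

As in the diff, sorted=False skips the ordering pass, since only membership in the top-k matters for a mask.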
@@ -118,6 +119,7 @@ def _compute_mask(x: torch.Tensor, k_float: float, x_for_ranking: Optional[torch

         if k_per_token > 0:
             _, topk_indices_per_row = torch.topk(ranking_tensor_to_use, k_per_token, dim=-1, sorted=False)
+            # Use scatter_ for efficient mask creation
             mask = torch.zeros_like(x, dtype=torch.bool)
             mask.scatter_(-1, topk_indices_per_row, True)
         else:
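TokenTopK's variant ranks within each row (token) along the last dimension instead, and scatter_ marks each row's winners in one vectorized call. A self-contained sketch under the same assumption of illustrative naming:

    import torch

    def token_topk_mask(x: torch.Tensor, k_per_token: int) -> torch.Tensor:
        # Each token (row) keeps its own top-k along the feature dimension.
        if k_per_token <= 0:
            return torch.zeros_like(x, dtype=torch.bool)
        _, topk_indices_per_row = torch.topk(x, k_per_token, dim=-1, sorted=False)
        mask = torch.zeros_like(x, dtype=torch.bool)
        # scatter_ writes True at every row's selected column indices at once.
        mask.scatter_(-1, topk_indices_per_row, True)
        return mask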
@@ -231,6 +233,7 @@ def _apply_batch_topk_helper(
     dtype: torch.dtype,
     rank: int,
     process_group: Optional[ProcessGroup],
+    profiler: Optional[Any] = None,
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply BatchTopK globally across concatenated layer pre-activations."""

@@ -304,17 +307,42 @@ def _apply_batch_topk_helper(

     if world_size > 1:
         if rank == 0:
-            local_mask = BatchTopK._compute_mask(
-                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
-            )
+            if profiler:
+                with profiler.timer("batchtopk_compute_mask") as timer:
+                    local_mask = BatchTopK._compute_mask(
+                        concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                    )
+                if hasattr(timer, 'elapsed'):
+                    profiler.record("batchtopk_compute_mask", timer.elapsed)
+            else:
+                local_mask = BatchTopK._compute_mask(
+                    concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                )
             mask.copy_(local_mask)
-            dist_ops.broadcast(mask, src=0, group=process_group)
+
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist_ops.broadcast(mask, src=0, group=process_group)
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
     else:
-        mask = BatchTopK._compute_mask(
-            concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
-        )
+        if profiler:
+            with profiler.timer("batchtopk_compute_mask") as timer:
+                mask = BatchTopK._compute_mask(
+                    concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                )
+            if hasattr(timer, 'elapsed'):
+                profiler.record("batchtopk_compute_mask", timer.elapsed)
+        else:
+            mask = BatchTopK._compute_mask(
+                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+            )

     activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

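The if profiler / else shape above repeats for both the mask computation and the broadcast, and recurs again in the token helper below. One hypothetical consolidation, not part of this PR, is a context manager that degrades to a no-op when no profiler is passed; maybe_profile and its perf_counter timing are assumptions, since the real profiler's timer internals are not shown in this diff:

    import contextlib
    import time
    from typing import Any, Iterator, Optional

    @contextlib.contextmanager
    def maybe_profile(profiler: Optional[Any], name: str) -> Iterator[None]:
        """Time the enclosed block and report it, or do nothing without a profiler."""
        if profiler is None:
            yield
            return
        start = time.perf_counter()
        try:
            yield
        finally:
            profiler.record(name, time.perf_counter() - start)

Call sites would then shrink to a single with maybe_profile(profiler, "batchtopk_compute_mask"): around each _compute_mask call, removing the duplicated else branches.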
@@ -336,6 +364,7 @@ def _apply_token_topk_helper(
     dtype: torch.dtype,
     rank: int,
     process_group: Optional[ProcessGroup],
+    profiler: Optional[Any] = None,
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply TokenTopK globally across concatenated layer pre-activations."""
     world_size = dist_ops.get_world_size(process_group)
@@ -408,19 +437,46 @@ def _apply_token_topk_helper(

     if world_size > 1:
         if rank == 0:
-            local_mask = TokenTopK._compute_mask(
-                concatenated_preactivations_original,
-                k_val_float,
-                concatenated_preactivations_normalized,
-            )
+            if profiler:
+                with profiler.timer("topk_compute_mask") as timer:
+                    local_mask = TokenTopK._compute_mask(
+                        concatenated_preactivations_original,
+                        k_val_float,
+                        concatenated_preactivations_normalized,
+                    )
+                if hasattr(timer, 'elapsed'):
+                    profiler.record("topk_compute_mask", timer.elapsed)
+            else:
+                local_mask = TokenTopK._compute_mask(
+                    concatenated_preactivations_original,
+                    k_val_float,
+                    concatenated_preactivations_normalized,
+                )
             mask.copy_(local_mask)
-            dist_ops.broadcast(mask, src=0, group=process_group)
+
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("topk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist_ops.broadcast(mask, src=0, group=process_group)
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("topk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
     else:
-        mask = TokenTopK._compute_mask(
-            concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
-        )
+        if profiler:
+            with profiler.timer("topk_compute_mask") as timer:
+                mask = TokenTopK._compute_mask(
+                    concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
+                )
+            if hasattr(timer, 'elapsed'):
+                profiler.record("topk_compute_mask", timer.elapsed)
+        else:
+            mask = TokenTopK._compute_mask(
+                concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
+            )

     activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

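Taken together, the hunks assume a duck-typed profiler: timer(name) is a context manager whose value may expose elapsed, record(name, seconds) aggregates timings, and an optional dist_profiler attribute offers profile_op(name) around collectives. A minimal stand-in satisfying that contract (purely illustrative; the project's actual profiler class is outside this diff):

    import contextlib
    import time
    from collections import defaultdict

    class MinimalProfiler:
        def __init__(self):
            self.totals = defaultdict(float)   # name -> accumulated seconds
            self.dist_profiler = None          # unprofiled broadcast path is taken

        @contextlib.contextmanager
        def timer(self, name):
            # Yield an object carrying `elapsed`, which call sites probe via hasattr.
            holder = type("Timer", (), {"elapsed": 0.0})()
            start = time.perf_counter()
            try:
                yield holder
            finally:
                holder.elapsed = time.perf_counter() - start

        def record(self, name, seconds):
            self.totals[name] += seconds

Passing profiler=MinimalProfiler() into either helper would accumulate mask-computation timings in totals, while the hasattr(profiler, 'dist_profiler') guard still routes broadcasts through the plain branch.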